diff --git a/README.md b/README.md index 4c3dc9ac..06368937 100644 --- a/README.md +++ b/README.md @@ -7,13 +7,13 @@

- + license - + python versions - + pypi version @@ -29,7 +29,7 @@

- Visit the docs for more info, or browse the releases. + Visit the docs for more info, or browse the releases.

Contributions are welcome! @@ -42,10 +42,11 @@ - [Overview](#overview) - [Features](#features) + * [Datasets](#datasets) * [Frameworks](#frameworks) * [Metrics](#metrics) - * [Datasets](#datasets) * [Schedules & Annealing](#schedules--annealing) +- [Architecture](#architecture) - [Examples](#examples) * [Python Example](#python-example) * [Hydra Config Example](#hydra-config-example) @@ -88,65 +89,94 @@ Please use the following citation if you use Disent in your own research: ---------------------- -## Architecture - -The disent module structure: +## Features -- `disent.dataset`: dataset wrappers, datasets & sampling strategies - + `disent.dataset.data`: raw datasets - + `disent.dataset.sampling`: sampling strategies for `DisentDataset` when multiple elements are required by frameworks, eg. for triplet loss - + `disent.dataset.transform`: common data transforms and augmentations - + `disent.dataset.wrapper`: wrapped datasets are no longer ground-truth datasets, these may have some elements masked out. We can still unwrap these classes to obtain the original datasets for benchmarking. -- `disent.frameworks`: frameworks, including Auto-Encoders and VAEs - + `disent.frameworks.ae`: Auto-Encoder based frameworks - + `disent.frameworks.vae`: Variational Auto-Encoder based frameworks -- `disent.metrics`: metrics for evaluating disentanglement using ground truth datasets -- `disent.model`: common encoder and decoder models used for VAE research -- `disent.nn`: torch components for building models including layers, transforms, losses and general maths -- `disent.schedule`: annealing schedules that can be registered to a framework -- `disent.util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework. +Disent includes implementations of modules, metrics and +datasets from various papers. 
-**Please Note The API Is Still Unstable ⚠️** +_Note that "🧵" means that the dataset, framework or metric was introduced by disent!_ -Disent is still under active development. Features and APIs are mostly stable but may change! A limited -set of tests currently exist which will be expanded upon in time. +### Datasets -**Hydra Experiment Directories** +Various common datasets used in disentanglement research are included, with hash +verification and automatic chunk-size optimization of underlying hdf5 formats for +low-memory disk-based access. -Easily run experiments with hydra config, these files -are not available from `pip install`. +Data input and target dataset augmentations and transforms are supported, as well as augmentations +on the GPU or CPU at different points in the pipeline. -- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config -- `experiment/config/config.yaml`: main configuration file, this is probably what you want to edit! -- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files -- `experiment/util`: various helper code for experiments +- **Ground Truth**: + +

+ 🚗 Cars3D

Cars3D Dataset Factor Traversals

+
+ + +
+ ◻️ dSprites

dSprites Dataset Factor Traversals

+
+ + +
+ 🔺 MPI3D

🏗 Todo

+
+ + +
+ 🐘 SmallNORB

Small Norb Dataset Factor Traversals

+
+ + +
+ 🌈 Shapes3D

Shapes3D Dataset Factor Traversals

+
+ + +
+ + 🧵 dSpritesImagenet: + Version of DSprite with foreground or background deterministically masked out with tiny-imagenet data. + +

dSpritesImagenet Dataset Factor Traversals

+
----------------------- +- **Ground Truth Synthetic**: + +
+ + 🧵 XYObject: + A simplistic version of dSprites with a single square. + +

XYObject Dataset Factor Traversals

+
+ + +
+ + 🧵 XYObjectShaded: + Exact same dataset as XYObject, but ground truth factors have a different representation. + +

XYObjectShaded Dataset Factor Traversals

+
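Each ground-truth dataset listed above enumerates every combination of its factor values, so a flat observation index corresponds to exactly one tuple of factor positions. The sketch below shows that mapping in plain Python, assuming row-major (C-order) indexing, which is the convention `np.unravel_index` uses by default. The helpers `index_to_factors` and `factors_to_index` and the factor sizes `(2, 3, 4)` are hypothetical illustrations, not disent's own implementation.

```python
from typing import Sequence, Tuple


def index_to_factors(idx: int, factor_sizes: Sequence[int]) -> Tuple[int, ...]:
    """Convert a flat index into per-factor positions (row-major order)."""
    total = 1
    for size in factor_sizes:
        total *= size
    if not 0 <= idx < total:
        raise IndexError(f'index {idx} is out of range for {total} observations')
    pos = []
    # peel off the fastest-varying (last) factor first
    for size in reversed(factor_sizes):
        idx, p = divmod(idx, size)
        pos.append(p)
    return tuple(reversed(pos))


def factors_to_index(pos: Sequence[int], factor_sizes: Sequence[int]) -> int:
    """Inverse of index_to_factors: fold factor positions into a flat index."""
    if len(pos) != len(factor_sizes):
        raise ValueError('pos and factor_sizes must have the same length')
    idx = 0
    for p, size in zip(pos, factor_sizes):
        if not 0 <= p < size:
            raise IndexError(f'factor position {p} is out of range for size {size}')
        idx = idx * size + p
    return idx


# hypothetical dataset with three factors of sizes 2, 3 and 4 (24 observations)
FACTOR_SIZES = (2, 3, 4)
print(index_to_factors(17, FACTOR_SIZES))
print(factors_to_index((1, 1, 1), FACTOR_SIZES))
```

Because the two helpers are exact inverses, a factor traversal can be produced by fixing all positions but one, stepping that one through its range, and converting each tuple back to an index.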
-## Features +### Frameworks -Disent includes implementations of modules, metrics and -datasets from various papers. Please note that items marked - with a "🧵" are introduced in and are unique to disent! +Disent provides the following Auto-Encoders and Variational Auto-Encoders! -### Frameworks - **Unsupervised**: - + [VAE](https://arxiv.org/abs/1312.6114) - + [Beta-VAE](https://openreview.net/forum?id=Sy2fzU9gl) - + [DFC-VAE](https://arxiv.org/abs/1610.00291) - + [DIP-VAE](https://arxiv.org/abs/1711.00848) - + [InfoVAE](https://arxiv.org/abs/1706.02262) - + [BetaTCVAE](https://arxiv.org/abs/1802.04942) + + AE: _Auto-Encoder_ + + [VAE](https://arxiv.org/abs/1312.6114): Variational Auto-Encoder + + [Beta-VAE](https://openreview.net/forum?id=Sy2fzU9gl): VAE with Scaled Loss + + [DFC-VAE](https://arxiv.org/abs/1610.00291): Deep Feature Consistent VAE + + [DIP-VAE](https://arxiv.org/abs/1711.00848): Disentangled Inferred Prior VAE + + [InfoVAE](https://arxiv.org/abs/1706.02262): Information Maximizing VAE + + [BetaTCVAE](https://arxiv.org/abs/1802.04942): Total Correlation VAE - **Weakly Supervised**: - + [Ada-GVAE](https://arxiv.org/abs/2002.02886) *`AdaVae(..., average_mode='gvae')`* Usually better than the Ada-ML-VAE - + [Ada-ML-VAE](https://arxiv.org/abs/2002.02886) *`AdaVae(..., average_mode='ml-vae')`* + + [Ada-GVAE](https://arxiv.org/abs/2002.02886): Adaptive GVAE, *`AdaVae.cfg(average_mode='gvae')`*, usually better than below! + + [Ada-ML-VAE](https://arxiv.org/abs/2002.02886): Adaptive ML-VAE, *`AdaVae.cfg(average_mode='ml-vae')`* - **Supervised**: - + [TVAE](https://arxiv.org/abs/1802.04403) - -Many popular disentanglement frameworks still need to be added, please -submit an issue if you have a request for an additional framework. + + TAE: _Triplet Auto-Encoder_ + + [TVAE](https://arxiv.org/abs/1802.04403): Triplet Variational Auto-Encoder -
todo

+

🏗 Todo: Many popular disentanglement frameworks still need to be added, please +submit an issue if you have a request for an additional framework.

+ FactorVAE + GroupVAE @@ -155,6 +185,9 @@ submit an issue if you have a request for an additional framework.

### Metrics +Various metrics are provided by disent that can be used to evaluate the +learnt representations of models that have been trained on ground-truth data. + - **Disentanglement**: + [FactorVAE Score](https://arxiv.org/abs/1802.05983) + [DCI](https://openreview.net/forum?id=By-7dz-AZ) @@ -162,43 +195,14 @@ submit an issue if you have a request for an additional framework. + [SAP](https://arxiv.org/abs/1711.00848) + [Unsupervised Scores](https://github.com/google-research/disentanglement_lib) -Some popular metrics still need to be added, please submit an issue if you wish to -add your own, or you have a request. - -
todo

+

🏗 Todo: Some popular metrics still need to be added, please submit an issue if you wish to +add your own, or you have a request.

+ [DCIMIG](https://arxiv.org/abs/1910.05587) + [Modularity and Explicitness](https://arxiv.org/abs/1802.05312)

-### Datasets - -Various common datasets used in disentanglement research are included, with hash -verification and automatic chunk-size optimization of underlying hdf5 formats for -low-memory disk-based access. - -- **Ground Truth**: - + Cars3D - + dSprites - + MPI3D - + SmallNORB - + Shapes3D - -- **Ground Truth Synthetic**: - + 🧵 XYObject: *A simplistic version of dSprites with a single square.* - + 🧵 XYObjectShaded: *Exact same dataset as XYObject, but ground truth factors have a different representation* - + 🧵 DSpritesImagenet: *Version of DSprite with foreground or background deterministically masked out with tiny-imagenet data* - -

- XYObject Dataset Factor Traversals -

- - #### Input Transforms + Input/Target Augmentations - - - Input based transforms are supported. - - Input and Target CPU and GPU based augmentations are supported. - ### Schedules & Annealing Hyper-parameter annealing is supported through the use of schedules. @@ -211,6 +215,43 @@ The currently implemented schedules include: ---------------------- +## Architecture + +The disent module structure: + +- `disent.dataset`: dataset wrappers, datasets & sampling strategies + + `disent.dataset.data`: raw datasets + + `disent.dataset.sampling`: sampling strategies for `DisentDataset` when multiple elements are required by frameworks, eg. for triplet loss + + `disent.dataset.transform`: common data transforms and augmentations + + `disent.dataset.wrapper`: wrapped datasets are no longer ground-truth datasets, these may have some elements masked out. We can still unwrap these classes to obtain the original datasets for benchmarking. +- `disent.frameworks`: frameworks, including Auto-Encoders and VAEs + + `disent.frameworks.ae`: Auto-Encoder based frameworks + + `disent.frameworks.vae`: Variational Auto-Encoder based frameworks +- `disent.metrics`: metrics for evaluating disentanglement using ground truth datasets +- `disent.model`: common encoder and decoder models used for VAE research +- `disent.nn`: torch components for building models including layers, transforms, losses and general maths +- `disent.schedule`: annealing schedules that can be registered to a framework +- `disent.util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework. + +**⚠️ The API Is _Mostly_ Stable ⚠️** + +Disent is still under development. Features and APIs are subject to change! +However, I will try and minimise the impact of these. + +A small suite of tests currently exist which will be expanded upon in time. + +**Hydra Experiment Directories** + +Easily run experiments with hydra config, these files +are not available from `pip install`.
+ +- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config +- `experiment/config/config.yaml`: main configuration file, this is probably what you want to edit! +- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files +- `experiment/util`: various helper code for experiments + +---------------------- + ## Examples ### Python Example @@ -357,7 +398,7 @@ visualisations of latent traversals. ### Why? -- Created as part of my Computer Science MSc scheduled for completion in 2021. +- Created as part of my Computer Science MSc which ended early 2022. - I needed custom high quality implementations of various VAE's. - A pytorch version of [disentanglement_lib](https://github.com/google-research/disentanglement_lib). - I didn't have time to wait for [Weakly-Supervised Disentanglement Without Compromises](https://arxiv.org/abs/2002.02886) to release diff --git a/disent/dataset/__init__.py b/disent/dataset/__init__.py index 84f37c72..e44c1deb 100644 --- a/disent/dataset/__init__.py +++ b/disent/dataset/__init__.py @@ -24,3 +24,4 @@ # wrapper from disent.dataset._base import DisentDataset +from disent.dataset._base import DisentIterDataset diff --git a/disent/dataset/_base.py b/disent/dataset/_base.py index fdb3f985..7ab32153 100644 --- a/disent/dataset/_base.py +++ b/disent/dataset/_base.py @@ -22,7 +22,10 @@ # SOFTWARE. 
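The `_base.py` hunks in this diff rename the module-level sentinel to `_REF_ = object()` and use it as the default for every `shallow_copy` argument, so that "argument not passed" can be told apart from an explicit `None`. A minimal standalone sketch of that pattern follows; the `Wrapper` class and its fields are hypothetical stand-ins, not disent classes.

```python
from typing import Callable, Optional

# unique sentinel: distinguishes "not passed" from an explicit None
_REF_ = object()


class Wrapper:
    """Hypothetical stand-in for a dataset wrapper with optional settings."""

    def __init__(self, data, transform: Optional[Callable] = None, return_indices: bool = False):
        self.data = data
        self.transform = transform
        self.return_indices = return_indices

    def shallow_copy(self, data=_REF_, transform=_REF_, return_indices=_REF_) -> 'Wrapper':
        # reuse the current value unless the caller explicitly overrides it
        return Wrapper(
            data=self.data if (data is _REF_) else data,
            transform=self.transform if (transform is _REF_) else transform,
            return_indices=self.return_indices if (return_indices is _REF_) else return_indices,
        )


orig = Wrapper(data=[1, 2, 3], transform=str, return_indices=True)
kept = orig.shallow_copy()                   # nothing overridden
cleared = orig.shallow_copy(transform=None)  # an explicit None removes the transform

print(kept.transform is str, cleared.transform is None, cleared.data is orig.data)
```

With `None` as the default instead of a sentinel, the second call could not remove the transform, because `None` would be read as "keep the existing value".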
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import warnings from functools import wraps +from typing import Callable +from typing import Iterator from typing import Optional from typing import Sequence from typing import TypeVar @@ -30,6 +33,7 @@ import numpy as np from torch.utils.data import Dataset +from torch.utils.data import IterableDataset from torch.utils.data.dataloader import default_collate from disent.dataset.sampling import BaseDisentSampler @@ -81,19 +85,19 @@ def wrapper(self: 'DisentDataset', *args, **kwargs): # ========================================================================= # -_DO_COPY = object() +_REF_ = object() class DisentDataset(Dataset, LengthIter): - def __init__( self, - dataset: Union[Dataset, GroundTruthData], + dataset: Union[Dataset, GroundTruthData], # TODO: this should be renamed to data sampler: Optional[BaseDisentSampler] = None, - transform=None, - augment=None, - return_indices: bool = False, + transform: Optional[callable] = None, + augment: Optional[callable] = None, + return_indices: bool = False, # doesn't really hurt performance, might as well leave enabled by default? 
+ return_factors: bool = False, ): super().__init__() # save attributes @@ -102,22 +106,37 @@ def __init__( self._transform = transform self._augment = augment self._return_indices = return_indices + self._return_factors = return_factors + # check sampler + assert isinstance(self._sampler, BaseDisentSampler), f'{DisentDataset.__name__} got an invalid {BaseDisentSampler.__name__}: {type(self._sampler)}' # initialize sampler if not self._sampler.is_init: self._sampler.init(dataset) + # warn if we are overriding a transform + if self._transform is not None: + if hasattr(dataset, '_transform') and dataset._transform: + warnings.warn(f'{DisentDataset.__name__} has transform specified as well as wrapped dataset: {dataset}, are you sure this is intended?') + # check the dataset if we are returning the factors + if self._return_factors: + assert isinstance(self._dataset, GroundTruthData), f'If `return_factors` is `True`, then the dataset must be an instance of: {GroundTruthData.__name__}, got: {type(dataset)}' def shallow_copy( self, - transform=_DO_COPY, - augment=_DO_COPY, - return_indices=_DO_COPY, + dataset: Union[Dataset, GroundTruthData] =_REF_, # TODO: this should be renamed to data + sampler: Optional[BaseDisentSampler] = _REF_, + transform: Optional[callable] = _REF_, + augment: Optional[callable] = _REF_, + return_indices: bool = _REF_, + return_factors: bool = _REF_, ) -> 'DisentDataset': + # instantiate shallow dataset copy, overwriting elements if specified return DisentDataset( - dataset=self._dataset, - sampler=self._sampler, - transform=self._transform if (transform is _DO_COPY) else transform, - augment=self._augment if (augment is _DO_COPY) else augment, - return_indices=self._return_indices if (return_indices is _DO_COPY) else return_indices, + dataset = self._dataset if (dataset is _REF_) else dataset, + sampler = self._sampler.uninit_copy() if (sampler is _REF_) else sampler, + transform = self._transform if (transform is _REF_) else transform, + 
augment = self._augment if (augment is _REF_) else augment, + return_indices = self._return_indices if (return_indices is _REF_) else return_indices, + return_factors = self._return_factors if (return_factors is _REF_) else return_factors, ) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -128,6 +147,18 @@ def shallow_copy( def data(self) -> Dataset: return self._dataset + @property + def sampler(self) -> BaseDisentSampler: + return self._sampler + + @property + def transform(self) -> Optional[Callable[[object], object]]: + return self._transform + + @property + def augment(self) -> Optional[Callable[[object], object]]: + return self._augment + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Ground Truth Only # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -176,15 +207,22 @@ def wrapped_gt_data(self): return self._dataset.gt_data @wrapped_only - def unwrapped_disent_dataset(self) -> 'DisentDataset': - sampler = self._sampler.uninit_copy() - assert type(sampler) is type(self._sampler) - return DisentDataset( + def unwrapped_shallow_copy( + self, + sampler: Optional[BaseDisentSampler] = _REF_, + transform: Optional[callable] = _REF_, + augment: Optional[callable] = _REF_, + return_indices: bool = _REF_, + return_factors: bool = _REF_, + ) -> 'DisentDataset': + # like shallow_copy, but unwrap the dataset instead! 
+ return self.shallow_copy( dataset=self.wrapped_data, sampler=sampler, - transform=self._transform, - augment=self._augment, - return_indices=self._return_indices, + transform=transform, + augment=augment, + return_indices=return_indices, + return_factors=return_factors, ) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -275,6 +313,12 @@ def _dataset_get_observation(self, *idxs): # add indices if self._return_indices: obs['idx'] = idxs + # add factors + if self._return_factors: + # >>> this is about 10% faster than below, because we do not need to do conversions! + obs['factors'] = tuple(np.array(np.unravel_index(idxs, self._dataset.factor_sizes)).T) + # >>> builtin but slower method, does some magic for more than 2 dims, could replace with faster try_njit method, but then we need numba! + # obs['factors1'] = tuple(self.gt_data.idx_to_pos(idxs)) # done! return obs @@ -285,38 +329,57 @@ def _dataset_get_observation(self, *idxs): # TODO: default_collate should be replaced with a function # that can handle tensors and nd.arrays, and return accordingly - def dataset_batch_from_indices(self, indices: Sequence[int], mode: str): + def dataset_batch_from_indices(self, indices: Sequence[int], mode: str, collate: bool = True): """Get a batch of observations X from a batch of factors Y.""" - return default_collate([self.dataset_get(idx, mode=mode) for idx in indices]) + batch = [self.dataset_get(idx, mode=mode) for idx in indices] + return default_collate(batch) if collate else batch - def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False): + def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False, collate: bool = True, seed: Optional[int] = None): """Sample a batch of observations X.""" # built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672 - indices = 
random_choice_prng(len(self), size=num_samples, replace=replace) + indices = random_choice_prng(len(self._dataset), size=num_samples, replace=replace, seed=seed) # return batch - batch = self.dataset_batch_from_indices(indices, mode=mode) + batch = self.dataset_batch_from_indices(indices, mode=mode, collate=collate) # return values if return_indices: - return batch, default_collate(indices) + return batch, (default_collate(indices) if collate else indices) else: return batch + def dataset_sample_elems(self, num_samples: int, mode: str, return_indices: bool = False, seed: Optional[int] = None): + """Sample uncollated elements with replacement, like `dataset_sample_batch`""" + return self.dataset_sample_batch(num_samples=num_samples, mode=mode, replace=True, return_indices=return_indices, collate=False, seed=seed) + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Batches -- Ground Truth Only # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @groundtruth_only - def dataset_batch_from_factors(self, factors: np.ndarray, mode: str): + def dataset_batch_from_factors(self, factors: np.ndarray, mode: str, collate: bool = True): """Get a batch of observations X from a batch of factors Y.""" indices = self.gt_data.pos_to_idx(factors) - return self.dataset_batch_from_indices(indices, mode=mode) + return self.dataset_batch_from_indices(indices, mode=mode, collate=collate) @groundtruth_only - def dataset_sample_batch_with_factors(self, num_samples: int, mode: str): + def dataset_sample_batch_with_factors(self, num_samples: int, mode: str, collate: bool = True): """Sample a batch of observations X and factors Y.""" factors = self.gt_data.sample_factors(num_samples) - batch = self.dataset_batch_from_factors(factors, mode=mode) - return batch, default_collate(factors) + batch = self.dataset_batch_from_factors(factors, mode=mode, collate=collate) + return batch, (default_collate(factors) if collate else factors) + + +class 
DisentIterDataset(IterableDataset, DisentDataset): + + # make sure we cannot obtain the length directly + __len__ = None + + def __iter__(self): + # this takes priority over __getitem__, otherwise __getitem__ would need to + # raise an IndexError if out of bounds to signal the end of iteration + while True: + # yield the entire dataset + # - repeating when it is done! + yield from (self[i] for i in range(len(self._dataset))) # ========================================================================= # diff --git a/disent/dataset/data/__init__.py b/disent/dataset/data/__init__.py index 1a870327..4dab1de0 100644 --- a/disent/dataset/data/__init__.py +++ b/disent/dataset/data/__init__.py @@ -43,9 +43,11 @@ # groundtruth -- impl from disent.dataset.data._groundtruth__cars3d import Cars3dData +from disent.dataset.data._groundtruth__cars3d import Cars3d64Data # optimized version of cars3d for 64x64 images from disent.dataset.data._groundtruth__dsprites import DSpritesData from disent.dataset.data._groundtruth__mpi3d import Mpi3dData from disent.dataset.data._groundtruth__norb import SmallNorbData +from disent.dataset.data._groundtruth__norb import SmallNorb64Data from disent.dataset.data._groundtruth__shapes3d import Shapes3dData # groundtruth -- impl synthetic diff --git a/disent/dataset/data/_groundtruth.py b/disent/dataset/data/_groundtruth.py index 0c269ba4..faa86ba1 100644 --- a/disent/dataset/data/_groundtruth.py +++ b/disent/dataset/data/_groundtruth.py @@ -85,6 +85,15 @@ def factor_names(self) -> Tuple[str, ...]: def factor_sizes(self) -> Tuple[int, ...]: raise NotImplementedError() + def state_space_copy(self) -> StateSpace: + """ + :return: Copy this ground truth dataset as a StateSpace, discarding everything else! 
+ """ + return StateSpace( + factor_sizes=self.factor_sizes, + factor_names=self.factor_names, + ) + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Properties # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -214,7 +223,7 @@ def _mixin_disk_init(self, data_root: Optional[str] = None, prepare: bool = Fals if data_root is None: data_root = self.default_data_root else: - data_root = os.path.abspath(data_root) + data_root = os.path.abspath(os.path.expanduser(data_root)) # get class data folder self._data_dir = ensure_dir_exists(os.path.join(data_root, self.name)) log.info(f'{self.name}: data_dir_share={repr(self._data_dir)}') diff --git a/disent/dataset/data/_groundtruth__cars3d.py b/disent/dataset/data/_groundtruth__cars3d.py index 522bc749..021c2123 100644 --- a/disent/dataset/data/_groundtruth__cars3d.py +++ b/disent/dataset/data/_groundtruth__cars3d.py @@ -26,13 +26,16 @@ import os import shutil from tempfile import TemporaryDirectory +from typing import Dict +from typing import Optional +from typing import Union import numpy as np -from scipy.io import loadmat -from disent.dataset.util.datafile import DataFileHashedDlGen from disent.dataset.data._groundtruth import NumpyFileGroundTruthData -from disent.util.inout.files import AtomicSaveFile +from disent.dataset.util.datafile import DataFileHashed +from disent.dataset.util.datafile import DataFileHashedDlGen +from disent.util.inout.paths import modify_name_keep_ext log = logging.getLogger(__name__) @@ -52,6 +55,7 @@ def load_cars3d_folder(raw_data_dir): 2. /data/sprites 3. 
/data/shapes48.mat """ + from scipy.io import loadmat # load image paths with open(os.path.join(raw_data_dir, 'cars/list.txt'), 'r') as img_names: img_paths = [os.path.join(raw_data_dir, f'cars/{name.strip()}.mat') for name in img_names.readlines()] @@ -72,10 +76,20 @@ def resave_cars3d_archive(orig_zipped_file, new_save_file, overwrite=False): # extract zipfile and get path log.info(f"Extracting into temporary directory: {temp_dir}") shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir) - # load image paths & resave - with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file: - images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data')) - np.savez(temp_file, images=images) + # load images + images = load_cars3d_folder(raw_data_dir=os.path.join(temp_dir, 'data')) + # save the array + from disent.dataset.util.npz import save_dataset_array + save_dataset_array(images, new_save_file, overwrite=overwrite, save_key='images') + + +def resave_cars3d_resized(orig_converted_file: str, new_resized_file: str, overwrite=False, size: int = 64): + # load the array + cars3d_array = np.load(orig_converted_file)['images'] + assert cars3d_array.shape == (17568, 128, 128, 3) + # save the array + from disent.dataset.util.npz import save_resized_dataset_array + save_resized_dataset_array(cars3d_array, new_resized_file, overwrite=overwrite, size=size, save_key='images') # ========================================================================= # @@ -91,6 +105,35 @@ def _generate(self, inp_file: str, out_file: str): resave_cars3d_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True) +class DataFileCars3dResized(DataFileHashed): + + def __init__( + self, + cars3d_datafile: DataFileCars3d, + # - convert file name + out_hash: Optional[Union[str, Dict[str, str]]], + out_name: Optional[str] = None, + out_size: int = 64, + # - hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._out_size = out_size + 
self._cars3dfile = cars3d_datafile + super().__init__( + file_name=modify_name_keep_ext(self._cars3dfile.out_name, suffix=f'_x{out_size}') if (out_name is None) else out_name, + file_hash=out_hash, + hash_type=hash_type, + hash_mode=hash_mode, + ) + + def _prepare(self, out_dir: str, out_file: str): + log.debug('Preparing Orig Cars3d Data:') + cars3d_path = self._cars3dfile.prepare(out_dir) + log.debug('Generating Resized Cars3d Data:') + resave_cars3d_resized(orig_converted_file=cars3d_path, new_resized_file=out_file, overwrite=True, size=self._out_size) + + # ========================================================================= # # dataset_cars3d # # ========================================================================= # @@ -115,7 +158,7 @@ class Cars3dData(NumpyFileGroundTruthData): uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz', uri_hash={'fast': 'fe77d39e3fa9d77c31df2262660c2a67', 'full': '4e866a7919c1beedf53964e6f7a23686'}, file_name='cars3d.npz', - file_hash={'fast': 'ef5d86d1572ddb122b466ec700b3abf2', 'full': 'dc03319a0b9118fbe0e23d13220a745b'}, + file_hash={'fast': '204ecb6852216e333f1b022903f9d012', 'full': '46ad66acf277897f0404e522460ba7e5'}, hash_mode='fast' ) @@ -123,11 +166,50 @@ class Cars3dData(NumpyFileGroundTruthData): data_key = 'images' +# TODO: this is very slow compared to other datasets for some reason! +# - in memory benchmark are equivalent, eg. against Shapes3D, but when we run the +# experiment/run.py with this its about twice as slow? Why is this? +class Cars3d64Data(Cars3dData): + """ + Optimized version of Cars3dOrigData, that has already been re-sized to 64x64 + - This can improve run times dramatically! 
+ """ + + img_shape = (64, 64, 3) + + datafile = DataFileCars3dResized( + cars3d_datafile=Cars3dData.datafile, + out_name='cars3d_x64.npz', + out_hash={'fast': '5a85246b6f555bc6e3576ee62bf6d19e', 'full': '2b900b3c5de6cd9b5df87bfc02f01f03'}, + hash_mode='fast', + out_size=64, + ) + + # ========================================================================= # # END # # ========================================================================= # if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) - Cars3dData(prepare=True) + + def main(): + import torch + from tqdm import tqdm + from disent.dataset.transform import ToImgTensorF32 + + logging.basicConfig(level=logging.DEBUG) + + # original dataset + data_128 = Cars3dData(prepare=True, transform=ToImgTensorF32(size=64)) + for i in tqdm(data_128, desc='cars3d_x128 -> 64'): + pass + # resized dataset + data_64 = Cars3d64Data(prepare=True, transform=ToImgTensorF32(size=64)) + for i in tqdm(data_64, desc='cars3d_x64'): + pass + # check equivalence + for obs_128, obs_64 in tqdm(zip(data_128, data_64), desc='equivalence'): + assert torch.allclose(obs_128, obs_64) + + main() diff --git a/disent/dataset/data/_groundtruth__dsprites.py b/disent/dataset/data/_groundtruth__dsprites.py index 99ee76ae..d0867414 100644 --- a/disent/dataset/data/_groundtruth__dsprites.py +++ b/disent/dataset/data/_groundtruth__dsprites.py @@ -33,6 +33,7 @@ # ========================================================================= # +# TODO: this seems to have a memory leak compared to other datasets? 
class DSpritesData(Hdf5GroundTruthData): """ DSprites Dataset diff --git a/disent/dataset/data/_groundtruth__dsprites_imagenet.py b/disent/dataset/data/_groundtruth__dsprites_imagenet.py deleted file mode 100644 index 13b80134..00000000 --- a/disent/dataset/data/_groundtruth__dsprites_imagenet.py +++ /dev/null @@ -1,290 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import os -import shutil -from tempfile import TemporaryDirectory -from typing import Optional - -import numpy as np -import psutil -from torch.utils.data import DataLoader -from torch.utils.data import Dataset -from torchvision.datasets import ImageFolder -from tqdm import tqdm - -from disent.dataset.data import GroundTruthData -from disent.dataset.data._groundtruth import _DiskDataMixin -from disent.dataset.data._groundtruth import _Hdf5DataMixin -from disent.dataset.data._groundtruth__dsprites import DSpritesData -from disent.dataset.transform import ToImgTensorF32 -from disent.dataset.util.datafile import DataFileHashedDlGen -from disent.dataset.util.hdf5 import H5Builder -from disent.dataset.util.stats import compute_data_mean_std -from disent.util.inout.files import AtomicSaveFile -from disent.util.iters import LengthIter -from disent.util.math.random import random_choice_prng - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# load imagenet-tiny data # -# ========================================================================= # - - -class NumpyFolder(ImageFolder): - def __getitem__(self, idx): - img, cls = super().__getitem__(idx) - return np.array(img) - - -def _noop(x): - return x - - -def load_imagenet_tiny_data(raw_data_dir): - data = NumpyFolder(os.path.join(raw_data_dir, 'train')) - data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=_noop) - # load data - this is a bit memory inefficient doing it like this instead of with a loop into a pre-allocated array - imgs = np.concatenate(list(tqdm(data, 'loading')), axis=0) - assert imgs.shape == (100_000, 64, 64, 3) - return imgs - - -def resave_imagenet_tiny_archive(orig_zipped_file, new_save_file, overwrite=False, h5_dataset_name: str = 'data'): - """ - 
Convert a imagenet tiny archive to an hdf5 or numpy file depending on the file extension. - Uncompressing the contents of the archive into a temporary directory in the same folder, - loading the images, then converting. - """ - _, ext = os.path.splitext(new_save_file) - assert ext in {'.npz', '.h5'}, f'unsupported save extension: {repr(ext)}, must be one of: {[".npz", ".h5"]}' - # extract zipfile into temp dir - with TemporaryDirectory(prefix='unzip_imagenet_tiny_', dir=os.path.dirname(orig_zipped_file)) as temp_dir: - log.info(f"Extracting into temporary directory: {temp_dir}") - shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir) - images = load_imagenet_tiny_data(raw_data_dir=os.path.join(temp_dir, 'tiny-imagenet-200')) - # save the data - with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file: - # check the mode - with H5Builder(temp_file, 'atomic_w') as builder: - builder.add_dataset_from_array( - name=h5_dataset_name, - array=images, - chunk_shape='batch', - compression_lvl=4, - attrs=None, - show_progress=True, - ) - - -# ========================================================================= # -# cars3d data object # -# ========================================================================= # - - -class ImageNetTinyDataFile(DataFileHashedDlGen): - """ - download the cars3d dataset and convert it to a hdf5 file. 
- """ - - dataset_name: str = 'data' - - def _generate(self, inp_file: str, out_file: str): - resave_imagenet_tiny_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True, h5_dataset_name=self.dataset_name) - - -class ImageNetTinyData(_Hdf5DataMixin, _DiskDataMixin, Dataset, LengthIter): - - name = 'imagenet_tiny' - - datafile_imagenet_h5 = ImageNetTinyDataFile( - uri='http://cs231n.stanford.edu/tiny-imagenet-200.zip', - uri_hash={'fast': '4d97ff8efe3745a3bba9917d6d536559', 'full': '90528d7ca1a48142e341f4ef8d21d0de'}, - file_hash={'fast': '9c23e8ec658b1ec9f3a86afafbdbae51', 'full': '4c32b0b53f257ac04a3afb37e3a4204e'}, - uri_name='tiny-imagenet-200.zip', - file_name='tiny-imagenet-200.h5', - hash_mode='full' - ) - - datafiles = (datafile_imagenet_h5,) - - def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None): - super().__init__() - self._transform = transform - # initialize mixin - self._mixin_disk_init( - data_root=data_root, - prepare=prepare, - ) - self._mixin_hdf5_init( - h5_path=os.path.join(self.data_dir, self.datafile_imagenet_h5.out_name), - h5_dataset_name=self.datafile_imagenet_h5.dataset_name, - in_memory=in_memory, - ) - - def __getitem__(self, idx: int): - obs = self._data[idx] - if self._transform is not None: - obs = self._transform(obs) - return obs - - -# ========================================================================= # -# dataset_dsprites # -# ========================================================================= # - - -class DSpritesImagenetData(GroundTruthData): - """ - DSprites that has imagenet images in the background. 
- """ - - name = 'dsprites_imagenet' - - # original dsprites it only (64, 64, 1) imagenet adds the colour channel - img_shape = (64, 64, 3) - factor_names = DSpritesData.factor_names - factor_sizes = DSpritesData.factor_sizes - - def __init__(self, visibility: int = 100, mode: str = 'bg', data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None): - super().__init__(transform=transform) - # check visibility and convert to ratio - assert isinstance(visibility, int), f'incorrect visibility percentage type, expected int, got: {type(visibility)}' - assert 0 <= visibility <= 100, f'incorrect visibility percentage: {repr(visibility)}, must be in range [0, 100]. ' - self._visibility = visibility / 100 - # check mode and convert to foreground boolean - assert mode in {'bg', 'fg'}, f'incorrect mode: {repr(mode)}, must be one of: ["bg", "fg"]' - self._foreground = (mode == 'fg') - # handle the datasets - self._dsprites = DSpritesData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None) - self._imagenet = ImageNetTinyData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None) - # deterministic randomization of the imagenet order - self._imagenet_order = random_choice_prng( - len(self._imagenet), - size=len(self), - seed=42, - ) - - def _get_observation(self, idx): - # we need to combine the two dataset images - # dsprites contains only {0, 255} for values - # we can directly use these values to mask the imagenet image - bg = self._imagenet[self._imagenet_order[idx]] - fg = self._dsprites[idx].repeat(3, axis=-1) - # compute background - # set foreground - r = self._visibility - if self._foreground: - # lerp content to white, and then insert into fg regions - # r*bg + (1-r)*255 - obs = (r*bg + ((1-r)*255)).astype('uint8') - obs[fg <= 127] = 0 - else: - # lerp content to black, and then insert into bg regions - # r*bg + (1-r)*000 - obs = (r*bg).astype('uint8') - obs[fg > 127] = 255 - # checks - return obs - 
- -# ========================================================================= # -# STATS # -# ========================================================================= # - - -""" -dsprites_fg_1.0 - vis_mean: [0.02067051643494642, 0.018688392816012946, 0.01632900510079384] - vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281] -dsprites_fg_0.8 - vis_mean: [0.024956427531012196, 0.02336780403840578, 0.021475119672280243] - vis_std: [0.11864125016313823, 0.11137998105649799, 0.10281424917834255] -dsprites_fg_0.6 - vis_mean: [0.029335176871153983, 0.028145355435322966, 0.026731731769287146] - vis_std: [0.13663242436043319, 0.13114320478634894, 0.1246542727733097] -dsprites_fg_0.4 - vis_mean: [0.03369999506331255, 0.03290657349801835, 0.03196482946320608] - vis_std: [0.155514074438101, 0.1518464537731621, 0.14750944591836743] -dsprites_fg_0.2 - vis_mean: [0.038064750024334834, 0.03766780505193579, 0.03719798677641122] - vis_std: [0.17498878664096565, 0.17315570657628318, 0.1709923319496426] -dsprites_bg_1.0 - vis_mean: [0.5020433619489952, 0.47206398913310593, 0.42380018909780404] - vis_std: [0.2505510666843685, 0.25007259803668697, 0.2562415603123114] -dsprites_bg_0.8 - vis_mean: [0.40867981393820857, 0.38468564002021527, 0.34611573047508204] - vis_std: [0.22048328737091344, 0.22102216869942384, 0.22692977053753477] -dsprites_bg_0.6 - vis_mean: [0.31676960943447674, 0.29877166834408025, 0.2698556821388113] - vis_std: [0.19745897110349003, 0.1986606891520453, 0.203808842880044] -dsprites_bg_0.4 - vis_mean: [0.2248598986983768, 0.21285772298967615, 0.19359577132944206] - vis_std: [0.1841631708032332, 0.18554895825833284, 0.1893568926398198] -dsprites_bg_0.2 - vis_mean: [0.13294969414492142, 0.12694375140936273, 0.11733572285575933] - vis_std: [0.18311250427586276, 0.1840916474752131, 0.18607373519458442] -""" - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - - - def compute_all_stats(): - from disent.util.visualize.plot 
import plt_subplots_imshow - - def compute_stats(visibility, mode): - # plot images - data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode) - grid = np.array([data[i*24733] for i in np.arange(16)]).reshape([4, 4, *data.img_shape]) - plt_subplots_imshow(grid, show=True, title=f'{DSpritesImagenetData.name} visibility={repr(visibility)} mode={repr(mode)}') - # compute stats - name = f'dsprites_{mode}_{visibility}' - data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode, transform=ToImgTensorF32()) - mean, std = compute_data_mean_std(data, batch_size=256, num_workers=min(psutil.cpu_count(logical=False), 64), progress=True) - print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}') - # return stats - return name, mean, std - - # compute common stats - stats = [] - for mode in ['fg', 'bg']: - for vis in [100, 80, 60, 40, 20]: - stats.append(compute_stats(vis, mode)) - - # print once at end - for name, mean, std in stats: - print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}') - - compute_all_stats() - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/dataset/data/_groundtruth__norb.py b/disent/dataset/data/_groundtruth__norb.py index c96d0505..9318b0e7 100644 --- a/disent/dataset/data/_groundtruth__norb.py +++ b/disent/dataset/data/_groundtruth__norb.py @@ -22,17 +22,23 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import gzip import logging -import os +from typing import Dict +from typing import NoReturn from typing import Optional -from typing import Sequence from typing import Tuple +from typing import Union import numpy as np +from disent.dataset.data import NumpyFileGroundTruthData +from disent.dataset.util.datafile import DataFile +from disent.dataset.util.datafile import DataFileHashed from disent.dataset.util.datafile import DataFileHashedDl -from disent.dataset.data._groundtruth import DiskGroundTruthData +from disent.util.inout.paths import modify_name_keep_ext + + +log = logging.getLogger(__name__) # ========================================================================= # @@ -81,6 +87,7 @@ def read_binary_matrix_bytes(bytes): def read_binary_matrix_file(file, gzipped: bool = True): + import gzip # this does not seem to copy the bytes, which saves memory with (gzip.open if gzipped else open)(file, "rb") as f: return read_binary_matrix_bytes(bytes=f.read()) @@ -91,11 +98,12 @@ def read_binary_matrix_file(file, gzipped: bool = True): # ========================================================================= # -def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True, sort=True) -> Tuple[np.ndarray, np.ndarray]: +def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True, sort=True, add_channel_dim: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ Load The Normalised Dataset * dat: - images (5 categories, 5 instances, 6 lightings, 9 elevations, and 18 azimuths) + + shape: (N, H, W, 1) * cat: - initial ground truth factor: 0. category of images (0 for animal, 1 for human, 2 for plane, 3 for truck, 4 for car). 
@@ -119,16 +127,106 @@ def read_norb_dataset(dat_path: str, cat_path: str, info_path: str, gzipped=True indices = np.lexsort(factors[:, [4, 3, 2, 1, 0]].T) images = images[indices] factors = factors[indices] + # add the channel dimension + if add_channel_dim: + images = images[:, :, :, None] + assert images.ndim == 4 + else: + assert images.ndim == 3 # done! return images, factors +def resave_norb_archive(in_dat_path: str, in_cat_path: str, in_info_path: str, new_save_file: str, in_gzipped=True, overwrite: bool = False): + # load the array + images, factors = read_norb_dataset(dat_path=in_dat_path, cat_path=in_cat_path, info_path=in_info_path, gzipped=in_gzipped, sort=True, add_channel_dim=True) + assert images.shape == (24300, 96, 96, 1) + # save the array + from disent.dataset.util.npz import save_dataset_array + save_dataset_array(images, new_save_file, overwrite=overwrite, save_key='images') + + +def resave_norb_resized(orig_converted_file: str, new_resized_file: str, overwrite=False, size: int = 64): + # load the array + norb_array = np.load(orig_converted_file)['images'] + assert norb_array.shape == (24300, 96, 96, 1) + # save the array + from disent.dataset.util.npz import save_resized_dataset_array + save_resized_dataset_array(norb_array, new_resized_file, overwrite=overwrite, size=size, save_key='images') + + +# ========================================================================= # +# Data Files # +# ========================================================================= # + + +class DataFileSmallNorb(DataFileHashed): + """ + download the smallnorb dataset and convert it to a numpy file. 
+ """ + + def __init__( + self, + datafile_dat: DataFile, + datafile_cat: DataFile, + datafile_info: DataFile, + out_name: str, + out_hash: Optional[Union[str, Dict[str, str]]], + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._datafile_dat = datafile_dat + self._datafile_cat = datafile_cat + self._datafile_info = datafile_info + # initialize + super().__init__(file_name=out_name, file_hash=out_hash, hash_type=hash_type, hash_mode=hash_mode) + + def _prepare(self, out_dir: str, out_file: str) -> NoReturn: + resave_norb_archive( + in_dat_path=self._datafile_dat.prepare(out_dir), + in_cat_path=self._datafile_cat.prepare(out_dir), + in_info_path=self._datafile_info.prepare(out_dir), + new_save_file=out_file, + in_gzipped=True, + overwrite=True, + ) + + +class DataFileSmallNorbResized(DataFileHashed): + + def __init__( + self, + norb_datafile: DataFileSmallNorb, + # - convert file name + out_hash: Optional[Union[str, Dict[str, str]]], + out_name: Optional[str] = None, + out_size: int = 64, + # - hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._out_size = out_size + self._norb_datafile = norb_datafile + super().__init__( + file_name=modify_name_keep_ext(self._norb_datafile.out_name, suffix=f'_x{out_size}') if (out_name is None) else out_name, + file_hash=out_hash, + hash_type=hash_type, + hash_mode=hash_mode, + ) + + def _prepare(self, out_dir: str, out_file: str): + log.debug('Preparing Orig SmallNorb Data:') + norb_path = self._norb_datafile.prepare(out_dir) + log.debug('Generating Resized SmallNorb Data:') + resave_norb_resized(orig_converted_file=norb_path, new_resized_file=out_file, overwrite=True, size=self._out_size) + + # ========================================================================= # # dataset_norb # # ========================================================================= # -class SmallNorbData(DiskGroundTruthData): +class SmallNorbData(NumpyFileGroundTruthData): """ Small NORB Dataset - 
https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/ @@ -143,33 +241,56 @@ class SmallNorbData(DiskGroundTruthData): factor_sizes = (5, 5, 9, 18, 6) # TOTAL: 24300 img_shape = (96, 96, 1) - TRAIN_DATA_FILES = { - 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), - 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), - 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), - } - - TEST_DATA_FILES = { - 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), - 'cat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), - 'info': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), - } - - def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_test=False, transform=None): + DATA_FILE_TRAIN = DataFileSmallNorb( + datafile_dat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': 
'66054832f9accfe74a0f4c36a75bc0a2'}), + datafile_cat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat.gz', uri_hash={'fast': '348fc3ccefd651d69f500611988b5dcd', 'full': '23c8b86101fbf0904a000b43d3ed2fd9'}), + datafile_info=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-info.mat.gz', uri_hash={'fast': 'f1b170c16925867c05f58608eb33ba7f', 'full': '51dee1210a742582ff607dfd94e332e3'}), + out_name='smallnorb_train.npz', + out_hash={'fast': 'a2c7de23c57b16c71b79dc2c884ecd67', 'full': '7dabafbfafa0eb9b0115452f82d1491e'}, + hash_mode='fast', + ) + + DATA_FILE_TEST = DataFileSmallNorb( + datafile_dat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat.gz', uri_hash={'fast': '9aee0b474a4fc2a2ec392b463efb8858', 'full': 'e4ad715691ed5a3a5f138751a4ceb071'}), + datafile_cat=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat.gz', uri_hash={'fast': '8cfae0679f5fa2df7a0aedfce90e5673', 'full': '5aa791cd7e6016cf957ce9bdb93b8603'}), + datafile_info=DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat.gz', uri_hash={'fast': 'd2703a3f95e7b9a970ad52e91f0aaf6a', 'full': 'a9454f3864d7fd4bb3ea7fc3eb84924e'}), + out_name='smallnorb_test.npz', + out_hash={'fast': 'ff027c01e14faea9a0d427d641d3bd8e', 'full': 'cec16d74a1b075fe4117b5167e16ceff'}, + hash_mode='fast', + ) + + # override + data_key = 'images' + + def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_test: bool = False, transform=None): self._is_test = is_test # initialize super().__init__(data_root=data_root, prepare=prepare, transform=transform) - # read dataset and sort by features - dat_path, cat_path, info_path = (os.path.join(self.data_dir, obj.out_name) for obj in 
self.datafiles) - self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path) - - def _get_observation(self, idx): - return self._data[idx][:, :, None] # data is missing channel dim @property - def datafiles(self) -> Sequence[DataFileHashedDl]: - norb_objects = self.TEST_DATA_FILES if self._is_test else self.TRAIN_DATA_FILES - return norb_objects['dat'], norb_objects['cat'], norb_objects['info'] + def datafile(self) -> DataFile: + return self.DATA_FILE_TEST if self._is_test else self.DATA_FILE_TRAIN + + +class SmallNorb64Data(SmallNorbData): + + img_shape = (64, 64, 1) + + DATA_FILE_TRAIN = DataFileSmallNorbResized( + norb_datafile=SmallNorbData.DATA_FILE_TRAIN, + out_name='smallnorb_train_x64.npz', + out_size=64, + out_hash={'fast': '74a3c02ea5a649313ea245a3fe271d3b', 'full': '88ce361b2198ee577e60da2be9daa0e8'}, + hash_mode='fast', + ) + + DATA_FILE_TEST = DataFileSmallNorbResized( + norb_datafile=SmallNorbData.DATA_FILE_TEST, + out_name='smallnorb_test_x64.npz', + out_size=64, + out_hash={'fast': '37bf364479c0954ecd707ace349541ef', 'full': '6bfd93eb6454d9d24dba13cac5f1ef3e'}, + hash_mode='fast', + ) # ========================================================================= # @@ -178,5 +299,21 @@ def datafiles(self) -> Sequence[DataFileHashedDl]: if __name__ == '__main__': + import torch + from tqdm import tqdm + from disent.dataset.transform import ToImgTensorF32 + logging.basicConfig(level=logging.DEBUG) - SmallNorbData(prepare=True) + + for is_test in [False, True]: + # original dataset + data_96 = SmallNorbData(prepare=True, is_test=is_test, transform=ToImgTensorF32(size=64)) + for i in tqdm(data_96, desc='norb_x96 -> 64'): + pass + # resized dataset + data_64 = SmallNorb64Data(prepare=True, is_test=is_test, transform=ToImgTensorF32(size=64)) + for i in tqdm(data_64, desc='norb_x64'): + pass + # check equivalence + for obs_96, obs_64 in tqdm(zip(data_96, data_64), desc='equivalence'): + assert torch.allclose(obs_96, obs_64) 
diff --git a/disent/dataset/data/_groundtruth__teapots3d.py b/disent/dataset/data/_groundtruth__teapots3d.py new file mode 100644 index 00000000..afaf991f --- /dev/null +++ b/disent/dataset/data/_groundtruth__teapots3d.py @@ -0,0 +1,164 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import Union + +import numpy as np + +from disent.dataset.data import NumpyFileGroundTruthData +from disent.dataset.util.datafile import DataFileHashed +from disent.util.inout.paths import modify_name_keep_ext + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# teapots 3d data processing # +# ========================================================================= # + + +def resave_teapots3d_as_uint8(orig_file: str, new_file: str, overwrite: bool = False): + # load data into memory ~10GB + # -- by default this array is stored as int32, instead of uint8 + log.debug('loading teapots data into memory, this may take a while...') + imgs = np.load(orig_file)['images'] + log.debug('loaded teapots data into memory!') + # checks + log.debug('checking teapots data...') + assert imgs.dtype == 'int32' + assert imgs.shape == (200_000, 64, 64, 3) + assert imgs.max() == 255 + assert imgs.min() == 0 + log.debug('checked teapots data!') + # convert the values + log.debug('converting teapots data to uint8...') + imgs = imgs.astype('uint8') + log.debug('converted teapots data!') + # save the array + from disent.dataset.util.npz import save_dataset_array + log.debug('saving converted teapots data...') + save_dataset_array(imgs, new_file, overwrite=overwrite, save_key='images') + log.debug('saved converted teapots data!') + + +# ========================================================================= # +# teapots 3d data files # +# ========================================================================= # + + +class DataFileTeapots3dInt32(DataFileHashed): + + # TODO: add a version of this file that automatically unpacks the original zip file? 
+ + def _prepare(self, out_dir: str, out_file: str) -> NoReturn: + if not os.path.exists(out_file): + raise FileNotFoundError( + f'Please download the Teapots3D dataset to: {repr(out_file)}' + f'\nThe original repository is: {repr("https://github.com/cianeastwood/qedr")}' + f'\nThe original download link is: {repr("https://www.dropbox.com/s/woeyomxuylqu7tx/edinburgh_teapots.zip?dl=0")}' + ) + + +class DataFileTeapots3dUint8(DataFileHashed): + + def __init__( + self, + teapots3d_datafile: DataFileTeapots3dInt32, + # - convert file name + out_hash: Optional[Union[str, Dict[str, str]]], + out_name: Optional[str] = None, + # - hash settings + hash_type: str = 'md5', + hash_mode: str = 'fast', + ): + self._teapots3dfile = teapots3d_datafile + super().__init__( + file_name=modify_name_keep_ext(self._teapots3dfile.out_name, suffix=f'_uint8') if (out_name is None) else out_name, + file_hash=out_hash, + hash_type=hash_type, + hash_mode=hash_mode, + ) + + def _prepare(self, out_dir: str, out_file: str): + log.debug('Preparing Orig Teapots3d Data:') + cars3d_path = self._teapots3dfile.prepare(out_dir) + log.debug('Converting Teapots3d Data to Uint8:') + resave_teapots3d_as_uint8(orig_file=cars3d_path, new_file=out_file, overwrite=True) + + +# ========================================================================= # +# teapots 3d dataset # +# ========================================================================= # + + +class Teapots3dData(NumpyFileGroundTruthData): + """ + Teapots3D Dataset + - A Framework for the Quantitative Evaluation of Disentangled Representations + * https://openreview.net/forum?id=By-7dz-AZ + * https://github.com/cianeastwood/qedr + + Manual Download Link: + - https://www.dropbox.com/s/woeyomxuylqu7tx/edinburgh_teapots.zip?dl=0 + + NOTE: + - This dataset is generated from ground-truth factors, HOWEVER, each datapoint + is randomly sampled. 
This dataset is NOT a typical grid-search over ground-truth factors + which means that we cannot create a StateSpace object over this dataset. + """ + + name = 'edinburgh_teapots' + + factor_names = ('azimuth', 'elevation', 'red', 'green', 'blue') + factor_sizes = (..., ..., ..., ..., ...) # TOTAL: 200_000 -- TODO: this is invalid, we cannot actually generate a StateSpace object over this dataset! + img_shape = (64, 64, 3) + + datafile = DataFileTeapots3dUint8( + teapots3d_datafile=DataFileTeapots3dInt32( + file_name='teapots.npz', + file_hash={'full': '9b58d66a382d01f4477e33520f1fa503', 'fast': '12c889e001c205d0bafa59dfff114102'}, + ), + out_hash={'full': 'e64207ee443030d310500d762f0d1dfd', 'fast': '7fbca0223c27e055d35b6d5af720f108'}, + out_name='teapots_uint8.npz', + ) + + # override + data_key = 'images' + + +# ========================================================================= # +# main # +# ========================================================================= # + + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + data = Teapots3dData(data_root='~/Downloads', prepare=True) diff --git a/disent/dataset/data/_raw.py b/disent/dataset/data/_raw.py index 219cda98..5a30f45c 100644 --- a/disent/dataset/data/_raw.py +++ b/disent/dataset/data/_raw.py @@ -29,7 +29,7 @@ # ========================================================================= # -# hdf5 utils # +# array utils # # ========================================================================= # diff --git a/disent/dataset/sampling/__init__.py b/disent/dataset/sampling/__init__.py index b9bcaf8f..5dd9a9c0 100644 --- a/disent/dataset/sampling/__init__.py +++ b/disent/dataset/sampling/__init__.py @@ -31,6 +31,7 @@ from disent.dataset.sampling._groundtruth__pair_orig import GroundTruthPairOrigSampler from disent.dataset.sampling._groundtruth__single import GroundTruthSingleSampler from disent.dataset.sampling._groundtruth__triplet import GroundTruthTripleSampler +from 
disent.dataset.sampling._groundtruth__walk import GroundTruthRandomWalkSampler # any dataset samplers from disent.dataset.sampling._single import SingleSampler diff --git a/disent/dataset/sampling/_base.py b/disent/dataset/sampling/_base.py index d678dc6c..5bed3bb9 100644 --- a/disent/dataset/sampling/_base.py +++ b/disent/dataset/sampling/_base.py @@ -67,7 +67,7 @@ def is_init(self) -> bool: def _sample_idx(self, idx: int) -> Tuple[int, ...]: raise NotImplementedError - def __call__(self, idx: int) -> Tuple[int, ...]: + def sample(self, idx: int) -> Tuple[int, ...]: # check that we have been initialized! if not self.is_init: raise RuntimeError(f'{self.__class__.__name__} has not been initialized! call `sampler.init(gt_data)`') @@ -79,6 +79,9 @@ def __call__(self, idx: int) -> Tuple[int, ...]: # return values return idxs + def __call__(self, idx: int) -> Tuple[int, ...]: + return self.sample(idx) + # ========================================================================= # # END # diff --git a/disent/dataset/sampling/_groundtruth__dist.py b/disent/dataset/sampling/_groundtruth__dist.py index 8f2a89d8..67c17bec 100644 --- a/disent/dataset/sampling/_groundtruth__dist.py +++ b/disent/dataset/sampling/_groundtruth__dist.py @@ -22,10 +22,21 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from fractions import Fraction +from typing import List +from typing import Optional +from typing import Union + import numpy as np from disent.dataset.data import GroundTruthData from disent.dataset.sampling._base import BaseDisentSampler +from disent.dataset.util.state_space import StateSpace + + +# ========================================================================= # +# Ground Truth Dist Sampler # +# ========================================================================= # class GroundTruthDistSampler(BaseDisentSampler): @@ -63,11 +74,11 @@ def __init__( self._sample_mode = triplet_sample_mode self._swap_chance = triplet_swap_chance # dataset variable - self._data: GroundTruthData + self._state_space: Optional[StateSpace] = None def _init(self, dataset): assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' - self._data = dataset + self._state_space = dataset.state_space_copy() # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Sampling # @@ -75,7 +86,7 @@ def _init(self, dataset): def _sample_idx(self, idx): # sample indices - indices = (idx, *np.random.randint(0, len(self._data), size=self._num_samples-1)) + indices = (idx, *np.random.randint(0, len(self._state_space), size=self._num_samples-1)) # sort based on mode if self._num_samples == 3: a_i, p_i, n_i = self._swap_triple(indices) @@ -89,31 +100,26 @@ def _sample_idx(self, idx): def _swap_triple(self, indices): a_i, p_i, n_i = indices - a_f, p_f, n_f = self._data.idx_to_pos(indices) - a_d, p_d, n_d = a_f, p_f, n_f - # dists vars - if self._scaled: - # range of positions is [0, f_size - 1], to scale between 0 and 1 we need to - # divide by (f_size - 1), but if the factor size is 1, we can't divide by zero - # so we make the minimum 1.0 - scale = np.maximum(1, np.array(self._data.factor_sizes) - 1) - a_d = 
a_d / scale - p_d = p_d / scale - n_d = n_d / scale + a_f, p_f, n_f = self._state_space.idx_to_pos(indices) + # get the scale for everything + # - range of positions is [0, f_size - 1], to scale between 0 and 1 we need to + # divide by (f_size - 1), but if the factor size is 1, we can't divide by zero + # so we make the minimum 1 + scale = np.maximum(1, self._state_space.factor_sizes - 1) if (self._scaled) else None + # SWAP: manhattan + if self._sample_mode == 'manhattan': + if factor_dist(a_f, p_f, scale=scale) > factor_dist(a_f, n_f, scale=scale): + return a_i, n_i, p_i # SWAP: factors - if self._sample_mode == 'factors': + elif self._sample_mode == 'factors': if factor_diff(a_f, p_f) > factor_diff(a_f, n_f): return a_i, n_i, p_i - # SWAP: manhattan - elif self._sample_mode == 'manhattan': - if factor_dist(a_d, p_d) > factor_dist(a_d, n_d): - return a_i, n_i, p_i # SWAP: combined elif self._sample_mode == 'combined': if factor_diff(a_f, p_f) > factor_diff(a_f, n_f): return a_i, n_i, p_i elif factor_diff(a_f, p_f) == factor_diff(a_f, n_f): - if factor_dist(a_d, p_d) > factor_dist(a_d, n_d): + if factor_dist(a_f, p_f, scale=scale) > factor_dist(a_f, n_f, scale=scale): return a_i, n_i, p_i # SWAP: random elif self._sample_mode != 'random': @@ -122,9 +128,114 @@ def _swap_triple(self, indices): return indices -def factor_diff(f0, f1): +def factor_diff(f0: np.ndarray, f1: np.ndarray) -> int: + # input types should be np.int64 + assert f0.dtype == f1.dtype == 'int64' + # compute distances! return np.sum(f0 != f1) -def factor_dist(f0, f1): - return np.sum(np.abs(f0 - f1)) +# NOTE: scaling here should always be the same as `disentangle_loss` +def factor_dist(f0: np.ndarray, f1: np.ndarray, scale: np.ndarray = None) -> Union[Fraction, int]: + # compute distances! 
+ if scale is None: + # input types should all be np.int64 + assert f0.dtype == f1.dtype == 'int64', f'invalid dtypes, f0: {f0.dtype}, f1: {f1.dtype}' + # we can simply sum if everything is already an integer + return np.sum(np.abs(f0 - f1)) + else: + # input types should all be np.int64 + assert f0.dtype == f1.dtype == scale.dtype == 'int64' + # Division results in precision errors! We cannot simply sum divided values. We instead + # store values as arbitrary precision rational numbers in the form of fractions This means + # we do not lose precision while summing, and avoid comparison errors! + # - https://shlegeris.com/2018/10/23/sqrt.html + # - https://cstheory.stackexchange.com/a/4010 + # 1. first we need to convert numbers to python arbitrary precision values: + f0: List[int] = f0.tolist() + f1: List[int] = f1.tolist() + scale: List[int] = scale.tolist() + # 2. we need to sum values in the form of fractions + total = Fraction(0) + for y0, y1, s in zip(f0, f1, scale): + total += Fraction(abs(y0 - y1), s) + return total + + +# ========================================================================= # +# Investigation: # +# ========================================================================= # + + +if __name__ == '__main__': + + def main(): + from disent.dataset import DisentDataset + from disent.dataset.data import XYObjectData + from disent.dataset.data import XYObjectShadedData + from disent.dataset.data import Cars3d64Data + from disent.dataset.data import Shapes3dData + from disent.dataset.data import DSpritesData + from disent.dataset.data import SmallNorb64Data + from disent.util.seeds import TempNumpySeed + from tqdm import tqdm + + repeats = 1000 + samples = 100 + + # RESULTS - manhattan: + # cars3d: orig_vs_divs=30.066%, orig_vs_frac=30.066%, divs_vs_frac=0.000% + # 3dshapes: orig_vs_divs=12.902%, orig_vs_frac=12.878%, divs_vs_frac=0.096% + # dsprites: orig_vs_divs=24.035%, orig_vs_frac=24.032%, divs_vs_frac=0.003% + # smallnorb: 
orig_vs_divs=18.601%, orig_vs_frac=18.598%, divs_vs_frac=0.005% + # xy_squares_minimal: orig_vs_divs= 1.389%, orig_vs_frac= 0.000%, divs_vs_frac=1.389% + # xy_object: orig_vs_divs=15.520%, orig_vs_frac=15.511%, divs_vs_frac=0.029% + # xy_object: orig_vs_divs=23.973%, orig_vs_frac=23.957%, divs_vs_frac=0.082% + # RESULTS - combined: + # cars3d: orig_vs_divs=15.428%, orig_vs_frac=15.428%, divs_vs_frac=0.000% + # 3dshapes: orig_vs_divs=4.982%, orig_vs_frac= 4.968%, divs_vs_frac=0.050% + # dsprites: orig_vs_divs=8.366%, orig_vs_frac= 8.363%, divs_vs_frac=0.003% + # smallnorb: orig_vs_divs=7.359%, orig_vs_frac= 7.359%, divs_vs_frac=0.000% + # xy_squares_minimal: orig_vs_divs=0.610%, orig_vs_frac= 0.000%, divs_vs_frac=0.610% + # xy_object: orig_vs_divs=7.622%, orig_vs_frac= 7.614%, divs_vs_frac=0.020% + # xy_object: orig_vs_divs=8.741%, orig_vs_frac= 8.733%, divs_vs_frac=0.046% + for mode in ['manhattan', 'combined']: + for data_cls in [ + Cars3d64Data, + Shapes3dData, + DSpritesData, + SmallNorb64Data, + XYObjectData, + XYObjectShadedData, + ]: + data = data_cls() + dataset_orig = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}')) + dataset_frac = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}_scaled')) + dataset_divs = DisentDataset(data, sampler=GroundTruthDistSampler(3, f'{mode}_scaled_INVALID')) + # calculate the average number of mismatches between sampling methods! 
+ all_wrong_frac = [] # frac vs orig + all_wrong_divs = [] # divs vs orig + all_wrong_diff = [] # frac vs divs + with TempNumpySeed(777): + progress = tqdm(range(repeats), desc=f'{mode} {data.name}') + for i in progress: + batch_seed = np.random.randint(0, 2**32) + with TempNumpySeed(batch_seed): idxs_orig = np.array([dataset_orig.sampler.sample(np.random.randint(0, len(dataset_orig))) for _ in range(samples)]) + with TempNumpySeed(batch_seed): idxs_frac = np.array([dataset_frac.sampler.sample(np.random.randint(0, len(dataset_frac))) for _ in range(samples)]) + with TempNumpySeed(batch_seed): idxs_divs = np.array([dataset_divs.sampler.sample(np.random.randint(0, len(dataset_divs))) for _ in range(samples)]) + # check number of miss_matches + all_wrong_frac.append(np.sum(np.any(idxs_orig != idxs_frac, axis=-1)) / samples * 100) + all_wrong_divs.append(np.sum(np.any(idxs_orig != idxs_divs, axis=-1)) / samples * 100) + all_wrong_diff.append(np.sum(np.any(idxs_frac != idxs_divs, axis=-1)) / samples * 100) + # update progress bar + progress.set_postfix({ + 'orig_vs_divs': f'{np.mean(all_wrong_divs):5.3f}%', + 'orig_vs_frac': f'{np.mean(all_wrong_frac):5.3f}%', + 'divs_vs_frac': f'{np.mean(all_wrong_diff):5.3f}%', + }) + main() + + +# ========================================================================= # +# END: # +# ========================================================================= # diff --git a/disent/dataset/sampling/_groundtruth__pair.py b/disent/dataset/sampling/_groundtruth__pair.py index 753bc1aa..9b7dcc32 100644 --- a/disent/dataset/sampling/_groundtruth__pair.py +++ b/disent/dataset/sampling/_groundtruth__pair.py @@ -22,10 +22,13 @@ # SOFTWARE. 
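Note on `factor_dist` in `_groundtruth__dist.py`: the change above sums the scaled per-factor distances as `Fraction`s instead of floats. Below is a minimal standalone sketch of why that matters for the swap comparison. It uses plain Python lists instead of numpy arrays, and `dist_float` / `dist_exact` are hypothetical names used only for illustration; they are not functions from disent.

```python
from fractions import Fraction


def dist_float(f0, f1, scale):
    # naive version: divide first, then sum the floats (rounding error accumulates)
    return sum(abs(a - b) / s for a, b, s in zip(f0, f1, scale))


def dist_exact(f0, f1, scale):
    # exact version: keep each term as a rational number and sum those
    total = Fraction(0)
    for a, b, s in zip(f0, f1, scale):
        total += Fraction(abs(a - b), s)
    return total


# two factors with 11 values each, so the scale is (11 - 1) = 10 for both
scale = [10, 10]
anchor, pos, neg = [0, 0], [1, 2], [3, 0]

# float arithmetic: 0.1 + 0.2 = 0.30000000000000004, while 3 / 10 = 0.3
float_swap = dist_float(anchor, pos, scale) > dist_float(anchor, neg, scale)
# exact arithmetic: both distances are 3/10, so neither is larger
exact_swap = dist_exact(anchor, pos, scale) > dist_exact(anchor, neg, scale)

print(float_swap, exact_swap)
```

In this sketch the float version would swap the positive and negative samples even though the two scaled distances are equal, which seems to be the kind of mismatch the `divs_vs_frac` column in the investigation above is counting.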
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import Optional + import numpy as np from disent.dataset.data import GroundTruthData from disent.dataset.sampling._base import BaseDisentSampler from disent.dataset.sampling._groundtruth__triplet import normalise_range_pair, FactorSizeError +from disent.dataset.util.state_space import StateSpace from disent.util.math.random import sample_radius @@ -60,15 +63,15 @@ def __init__( self.p_k_range = p_k_range self.p_radius_range = p_radius_range # dataset variable - self._data: GroundTruthData + self._state_space: Optional[StateSpace] def _init(self, dataset): assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' - self._data = dataset + self._state_space = dataset.state_space_copy() # DIFFERING FACTORS - self.p_k_min, self.p_k_max = self._min_max_from_range(p_range=self.p_k_range, max_values=self._data.num_factors) + self.p_k_min, self.p_k_max = self._min_max_from_range(p_range=self.p_k_range, max_values=self._state_space.num_factors) # RADIUS SAMPLING - self.p_radius_min, self.p_radius_max = self._min_max_from_range(p_range=self.p_radius_range, max_values=self._data.factor_sizes) + self.p_radius_min, self.p_radius_max = self._min_max_from_range(p_range=self.p_radius_range, max_values=self._state_space.factor_sizes) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # CORE # @@ -77,8 +80,8 @@ def _init(self, dataset): def _sample_idx(self, idx): f0, f1 = self.datapoint_sample_factors_pair(idx) return ( - self._data.pos_to_idx(f0), - self._data.pos_to_idx(f1), + self._state_space.pos_to_idx(f0), + self._state_space.pos_to_idx(f1), ) def datapoint_sample_factors_pair(self, idx): @@ -102,7 +105,7 @@ def datapoint_sample_factors_pair(self, idx): p_k = self._sample_num_factors() p_shared_indices = self._sample_shared_indices(p_k) # SAMPLE FACTORS - sample, resample and 
replace shared factors with originals - anchor_factors = self._data.idx_to_pos(idx) + anchor_factors = self._state_space.idx_to_pos(idx) positive_factors = self._resample_factors(anchor_factors) positive_factors[p_shared_indices] = anchor_factors[p_shared_indices] return anchor_factors, positive_factors @@ -125,11 +128,11 @@ def _sample_num_factors(self): return p_k def _sample_shared_indices(self, p_k): - p_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-p_k, replace=False) + p_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-p_k, replace=False) return p_shared_indices def _resample_factors(self, anchor_factors): - positive_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1) + positive_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1) return positive_factors diff --git a/disent/dataset/sampling/_groundtruth__pair_orig.py b/disent/dataset/sampling/_groundtruth__pair_orig.py index 02530fb0..76ec9a91 100644 --- a/disent/dataset/sampling/_groundtruth__pair_orig.py +++ b/disent/dataset/sampling/_groundtruth__pair_orig.py @@ -22,9 +22,12 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import Optional + import numpy as np from disent.dataset.data import GroundTruthData from disent.dataset.sampling._base import BaseDisentSampler +from disent.dataset.util.state_space import StateSpace class GroundTruthPairOrigSampler(BaseDisentSampler): @@ -47,11 +50,11 @@ def __init__( # DIFFERING FACTORS self.p_k = p_k # dataset variable - self._data: GroundTruthData + self._state_space: Optional[StateSpace] = None def _init(self, dataset): assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' - self._data = dataset + self._state_space = dataset.state_space_copy() # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # CORE # @@ -60,8 +63,8 @@ def _init(self, dataset): def _sample_idx(self, idx): f0, f1 = self.datapoint_sample_factors_pair(idx) return ( - self._data.pos_to_idx(f0), - self._data.pos_to_idx(f1), + self._state_space.pos_to_idx(f0), + self._state_space.pos_to_idx(f1), ) def datapoint_sample_factors_pair(self, idx): @@ -70,14 +73,14 @@ def datapoint_sample_factors_pair(self, idx): Except deterministic for the first item in the pair, based off of idx. """ # randomly sample the first observation -- In our case we just use the idx - sampled_factors = self._data.idx_to_pos(idx) + sampled_factors = self._state_space.idx_to_pos(idx) # sample the next observation with k differing factors - next_factors, k = _sample_k_differing(sampled_factors, self._data, k=self.p_k) + next_factors, k = _sample_k_differing(sampled_factors, self._state_space, k=self.p_k) # return the samples return sampled_factors, next_factors -def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1): +def _sample_k_differing(factors, state_space: StateSpace, k=1): """ Resample the factors used for the corresponding item in a pair. 
- Based on simple_dynamics() from: @@ -88,7 +91,7 @@ def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1): assert factors.ndim == 1 # sample k if k <= 0: - k = np.random.randint(1, ground_truth_data.num_factors) + k = np.random.randint(1, state_space.num_factors) # randomly choose 1 or k # TODO: This is in disentanglement lib, HOWEVER is this not a mistake? # A bug report has been submitted to disentanglement_lib for clarity: @@ -98,20 +101,20 @@ def _sample_k_differing(factors, ground_truth_data: GroundTruthData, k=1): index_list = np.random.choice(len(factors), k, replace=False) # randomly update factors for index in index_list: - factors[index] = np.random.choice(ground_truth_data.factor_sizes[index]) + factors[index] = np.random.choice(state_space.factor_sizes[index]) # return! return factors, k -def _sample_weak_pair_factors(gt_data: GroundTruthData): # pragma: no cover +def _sample_weak_pair_factors(state_space: StateSpace): # pragma: no cover """ Sample a weakly supervised pair from the given GroundTruthData. 
- Based on weak_dataset_generator() from: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/methods/weak/train_weak_lib.py """ # randomly sample the first observation - sampled_factors = gt_data.sample_factors(1) + sampled_factors = state_space.sample_factors(1) # sample the next observation with k differing factors - next_factors, k = _sample_k_differing(sampled_factors, gt_data, k=1) + next_factors, k = _sample_k_differing(sampled_factors, state_space, k=1) # return the samples return sampled_factors, next_factors diff --git a/disent/dataset/sampling/_groundtruth__single.py b/disent/dataset/sampling/_groundtruth__single.py index 885bc5ca..8198e9b0 100644 --- a/disent/dataset/sampling/_groundtruth__single.py +++ b/disent/dataset/sampling/_groundtruth__single.py @@ -23,8 +23,11 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Optional + from disent.dataset.data import GroundTruthData from disent.dataset.sampling._base import BaseDisentSampler +from disent.dataset.util.state_space import StateSpace log = logging.getLogger(__name__) @@ -42,10 +45,11 @@ def uninit_copy(self) -> 'GroundTruthSingleSampler': def __init__(self): super().__init__(num_samples=1) + self._state_space: Optional[StateSpace] = None # TODO: not actually needed def _init(self, dataset): assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' - self._data = dataset + self._state_space = dataset.state_space_copy() def _sample_idx(self, idx): return (idx,) diff --git a/disent/dataset/sampling/_groundtruth__triplet.py b/disent/dataset/sampling/_groundtruth__triplet.py index 231221a4..e87c6be7 100644 --- a/disent/dataset/sampling/_groundtruth__triplet.py +++ b/disent/dataset/sampling/_groundtruth__triplet.py @@ -23,12 +23,14 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ 
import logging +from typing import Optional from typing import Tuple from typing import Union import numpy as np from disent.dataset.data import GroundTruthData from disent.dataset.sampling._base import BaseDisentSampler +from disent.dataset.util.state_space import StateSpace from disent.util.math.random import sample_radius @@ -90,16 +92,16 @@ def __init__( if swap_chance is not None: assert 0 <= swap_chance <= 1, f'{swap_chance=} must be in range 0 to 1.' # dataset variable - self._data: GroundTruthData + self._state_space: Optional[StateSpace] def _init(self, dataset): assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' - self._data = dataset + self._state_space = dataset.state_space_copy() # DIFFERING FACTORS self.p_k_min, self.p_k_max, self.n_k_min, self.n_k_max = self._min_max_from_range( p_range=self.p_k_range, n_range=self.n_k_range, - max_values=self._data.num_factors, + max_values=self._state_space.num_factors, n_sample_mode=self.n_k_sample_mode, is_radius=False ) @@ -107,7 +109,7 @@ def _init(self, dataset): self.p_radius_min, self.p_radius_max, self.n_radius_min, self.n_radius_max = self._min_max_from_range( p_range=self.p_radius_range, n_range=self.n_radius_range, - max_values=self._data.factor_sizes, + max_values=self._state_space.factor_sizes, n_sample_mode=self.n_radius_sample_mode, is_radius=True ) @@ -119,9 +121,9 @@ def _init(self, dataset): def _sample_idx(self, idx): f0, f1, f2 = self.datapoint_sample_factors_triplet(idx) return ( - self._data.pos_to_idx(f0), - self._data.pos_to_idx(f1), - self._data.pos_to_idx(f2), + self._state_space.pos_to_idx(f0), + self._state_space.pos_to_idx(f1), + self._state_space.pos_to_idx(f2), ) def datapoint_sample_factors_triplet(self, idx): @@ -129,7 +131,7 @@ def datapoint_sample_factors_triplet(self, idx): p_k, n_k = self._sample_num_factors() p_shared_indices, n_shared_indices = self._sample_shared_indices(p_k, n_k) # 
SAMPLE FACTORS - sample, resample and replace shared factors with originals - anchor_factors = self._data.idx_to_pos(idx) + anchor_factors = self._state_space.idx_to_pos(idx) positive_factors, negative_factors = self._resample_factors(anchor_factors) positive_factors[p_shared_indices] = anchor_factors[p_shared_indices] negative_factors[n_shared_indices] = anchor_factors[n_shared_indices] @@ -189,7 +191,7 @@ def _sample_num_factors(self): p_k = np.random.randint(self.p_k_min, self.p_k_max + 1) # sample for negative if self.n_k_sample_mode == 'offset': - n_k = np.random.randint(p_k + self.n_k_min, min(p_k + self.n_k_max, self._data.num_factors) + 1) + n_k = np.random.randint(p_k + self.n_k_min, min(p_k + self.n_k_max, self._state_space.num_factors) + 1) elif self.n_k_sample_mode == 'bounded_below': n_k = np.random.randint(max(p_k, self.n_k_min), self.n_k_max + 1) elif self.n_k_sample_mode == 'random': @@ -200,18 +202,18 @@ def _sample_num_factors(self): return p_k, n_k def _sample_shared_indices(self, p_k, n_k): - p_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-p_k, replace=False) + p_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-p_k, replace=False) # sample for negative if self.n_k_is_shared: - n_shared_indices = p_shared_indices[:self._data.num_factors-n_k] + n_shared_indices = p_shared_indices[:self._state_space.num_factors-n_k] else: - n_shared_indices = np.random.choice(self._data.num_factors, size=self._data.num_factors-n_k, replace=False) + n_shared_indices = np.random.choice(self._state_space.num_factors, size=self._state_space.num_factors-n_k, replace=False) # we're done! 
return p_shared_indices, n_shared_indices def _resample_factors(self, anchor_factors): # sample positive - positive_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1) + positive_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=self.p_radius_min, r_high=self.p_radius_max + 1) # negative arguments if self.n_radius_sample_mode == 'offset': sampled_radius = np.abs(anchor_factors - positive_factors) @@ -227,7 +229,7 @@ def _resample_factors(self, anchor_factors): else: raise KeyError(f'Unknown mode: {self.n_radius_sample_mode=}') # sample negative - negative_factors = sample_radius(anchor_factors, low=0, high=self._data.factor_sizes, r_low=n_r_low, r_high=n_r_high) + negative_factors = sample_radius(anchor_factors, low=0, high=self._state_space.factor_sizes, r_low=n_r_low, r_high=n_r_high) # we're done! return positive_factors, negative_factors @@ -239,14 +241,14 @@ def _swap_factors(self, anchor_factors, positive_factors, negative_factors): p_dist = np.sum(np.abs(anchor_factors - positive_factors)) n_dist = np.sum(np.abs(anchor_factors - negative_factors)) elif self._swap_metric == 'manhattan_norm': - p_dist = np.sum(np.abs((anchor_factors - positive_factors) / np.subtract(self._data.factor_sizes, 1))) - n_dist = np.sum(np.abs((anchor_factors - negative_factors) / np.subtract(self._data.factor_sizes, 1))) + p_dist = np.sum(np.abs((anchor_factors - positive_factors) / np.subtract(self._state_space.factor_sizes, 1))) + n_dist = np.sum(np.abs((anchor_factors - negative_factors) / np.subtract(self._state_space.factor_sizes, 1))) elif self._swap_metric == 'euclidean': p_dist = np.linalg.norm(anchor_factors - positive_factors) n_dist = np.linalg.norm(anchor_factors - negative_factors) elif self._swap_metric == 'euclidean_norm': - p_dist = np.linalg.norm((anchor_factors - positive_factors) / np.subtract(self._data.factor_sizes, 1)) - n_dist = 
np.linalg.norm((anchor_factors - negative_factors) / np.subtract(self._data.factor_sizes, 1)) + p_dist = np.linalg.norm((anchor_factors - positive_factors) / np.subtract(self._state_space.factor_sizes, 1)) + n_dist = np.linalg.norm((anchor_factors - negative_factors) / np.subtract(self._state_space.factor_sizes, 1)) else: raise KeyError # perform swap diff --git a/disent/dataset/sampling/_groundtruth__walk.py b/disent/dataset/sampling/_groundtruth__walk.py new file mode 100644 index 00000000..37f5c297 --- /dev/null +++ b/disent/dataset/sampling/_groundtruth__walk.py @@ -0,0 +1,128 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple + +import numpy as np + +from disent.dataset.data import GroundTruthData +from disent.dataset.sampling import BaseDisentSampler +from disent.dataset.util.state_space import StateSpace +from disent.util.jit import try_njit + + +# ========================================================================= # +# Pretend We Are Walking Ground-Truth Factors Randomly # +# ========================================================================= # + + +class GroundTruthRandomWalkSampler(BaseDisentSampler): + + def uninit_copy(self) -> 'GroundTruthRandomWalkSampler': + return GroundTruthRandomWalkSampler( + num_samples=self._num_samples, + p_dist_max=self._p_dist_max, + n_dist_max=self._n_dist_max, + ) + + def __init__( + self, + num_samples: int = 3, + p_dist_max: int = 8, + n_dist_max: int = 32, + ): + super().__init__(num_samples=num_samples) + # checks + assert num_samples in {1, 2, 3}, f'num_samples ({repr(num_samples)}) must be 1, 2 or 3' + # save hparams + self._num_samples = num_samples + self._p_dist_max = p_dist_max + self._n_dist_max = n_dist_max + # dataset variable + self._state_space: Optional[StateSpace] = None + + def _init(self, dataset: GroundTruthData): + assert isinstance(dataset, GroundTruthData), f'dataset must be an instance of {repr(GroundTruthData.__class__.__name__)}, got: {repr(dataset)}' + self._state_space = dataset.state_space_copy() + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # Sampling # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + def _sample_idx(self, idx) -> Tuple[int, ...]: + if self._num_samples == 1: + return (idx,) + elif self._num_samples == 2: + p_dist = np.random.randint(1, self._p_dist_max + 1) + pos = _random_walk(idx, p_dist, self._state_space.factor_sizes) + return 
(idx, pos) + elif self._num_samples == 3: + p_dist = np.random.randint(1, self._p_dist_max + 1) + n_dist = np.random.randint(1, self._n_dist_max + 1) + pos = _random_walk(idx, p_dist, self._state_space.factor_sizes) + neg = _random_walk(pos, n_dist, self._state_space.factor_sizes) + return (idx, pos, neg) + else: + raise RuntimeError + + +# ========================================================================= # +# Helper # +# ========================================================================= # + + +def _random_walk(idx: int, dist: int, factor_sizes: np.ndarray) -> int: + # random walk + pos = np.array(np.unravel_index(idx, factor_sizes), dtype=int) # much faster than StateSpace.idx_to_pos, we don't need checks! + for _ in range(dist): + _walk_nearby_inplace(pos, factor_sizes) + idx = np.ravel_multi_index(pos, factor_sizes) # much faster than StateSpace.pos_to_idx, we don't need checks! + # done! + return int(idx) + + +@try_njit() +def _walk_nearby_inplace(pos: np.ndarray, factor_sizes: Sequence[int]) -> NoReturn: + # try to shift any single factor by 1 or -1 + while True: + f_idx = np.random.randint(0, len(factor_sizes)) + cur = pos[f_idx] + # walk random factor value + if np.random.random() < 0.5: + nxt = max(cur - 1, 0) + else: + nxt = min(cur + 1, factor_sizes[f_idx] - 1) + # exit if different + if cur != nxt: + break + # update the position + pos[f_idx] = nxt + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/dataset/transform/_augment.py b/disent/dataset/transform/_augment.py index 69082a1d..7cfb9748 100644 --- a/disent/dataset/transform/_augment.py +++ b/disent/dataset/transform/_augment.py @@ -24,6 +24,7 @@ import os import re +import warnings from numbers import Number from typing import List from typing import Tuple @@ -32,12 +33,13 @@ import numpy as np import torch -import disent from disent.nn.modules 
import DisentModule from disent.nn.functional import torch_box_kernel_2d from disent.nn.functional import torch_conv2d_channel_wise_fft from disent.nn.functional import torch_gaussian_kernel_2d +import disent.registry as R + # ========================================================================= # # Transforms # @@ -178,16 +180,22 @@ def _make_kernel(self, shape, device): # ========================================================================= # +_NO_ARG = object() + + class FftKernel(DisentModule): """ 2D Convolve an image """ - def __init__(self, kernel: Union[torch.Tensor, str], normalize: bool = True): + def __init__(self, kernel: Union[torch.Tensor, str], normalize_mode: str = _NO_ARG): super().__init__() + # deprecation error + if normalize_mode is _NO_ARG: + raise ValueError(f'default argument for normalize_mode was "sum", this has been deprecated and will change to "none" in future. Please manually override this value!') # load & save the kernel -- no gradients allowed self._kernel: torch.Tensor - self.register_buffer('_kernel', get_kernel(kernel, normalize=normalize), persistent=True) + self.register_buffer('_kernel', get_kernel(kernel, normalize_mode=normalize_mode), persistent=True) self._kernel.requires_grad = False def forward(self, obs): @@ -209,11 +217,30 @@ def forward(self, obs): # ========================================================================= # -def _normalise_kernel(kernel: torch.Tensor, normalize: bool) -> torch.Tensor: - if normalize: - with torch.no_grad(): - return kernel / kernel.sum() - return kernel + +@torch.no_grad() +def _scale_kernel(kernel: torch.Tensor, mode: Union[bool, str] = 'abssum'): + # old normalize mode + if isinstance(mode, bool): + raise ValueError(f'boolean arguments to `scale_kernel` are deprecated, convert True to "sum" and False to "none", got: {repr(mode)}') + # handle the normalize mode + if mode == 'sum': + return kernel / kernel.sum() + elif mode == 'abssum': + return kernel / 
torch.abs(kernel).sum() + elif mode == 'possum': + return kernel / torch.abs(kernel)[kernel > 0].sum() + elif mode == 'negsum': + return kernel / torch.abs(kernel)[kernel < 0].sum() + elif mode == 'maxsum': + return kernel / torch.maximum( + torch.abs(kernel)[kernel > 0].sum(), + torch.abs(kernel)[kernel < 0].sum(), + ) + elif mode == 'none': + return kernel + else: + raise KeyError(f'invalid scale mode: {repr(mode)}') def _check_kernel(kernel: torch.Tensor) -> torch.Tensor: @@ -227,39 +254,10 @@ def _check_kernel(kernel: torch.Tensor) -> torch.Tensor: return kernel -_KERNELS = { - # kernels that do not require arguments, just general factory functions - # name: class/fn -- with no required args -} - - -_ARG_KERNELS = [ - # (REGEX, EXAMPLE, FACTORY_FUNC) - # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group - # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first. - (re.compile(r'^(box)_r(\d+)$'), 'box_r31', lambda kern, radius: torch_box_kernel_2d(radius=int(radius))[None, ...]), - (re.compile(r'^(gau)_r(\d+)$'), 'gau_r31', lambda kern, radius: torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...]), -] - - # NOTE: this function compliments make_reconstruction_loss in frameworks/helper/reconstructions.py -def _make_kernel(name: str) -> torch.Tensor: - if name in _KERNELS: - # search normal losses! - return _KERNELS[name]() - else: - # regex search kernels, and call with args! 
- for r, _, fn in _ARG_KERNELS: - result = r.search(name) - if result is not None: - return fn(*result.groups()) - # we couldn't find anything - raise KeyError(f'Invalid kernel name: {repr(name)} Examples of argument based kernels include: {[example for _, example, _ in _ARG_KERNELS]}') - - -def make_kernel(name: str, normalize: bool = False): - kernel = _make_kernel(name) - kernel = _normalise_kernel(kernel, normalize=normalize) +def make_kernel(name: str, normalize_mode: str = 'none'): + kernel = R.KERNELS[name] + kernel = _scale_kernel(kernel, mode=normalize_mode) kernel = _check_kernel(kernel) return kernel @@ -267,21 +265,36 @@ def make_kernel(name: str, normalize: bool = False): def _get_kernel(name_or_path: str) -> torch.Tensor: if '/' not in name_or_path: try: - return _make_kernel(name_or_path) + return R.KERNELS[name_or_path] except KeyError: pass if os.path.isfile(name_or_path): return torch.load(name_or_path) - raise KeyError(f'Invalid kernel path or name: {repr(name_or_path)} Examples of argument based kernels include: {[example for _, example, _ in _ARG_KERNELS]}, otherwise specify a valid path to a kernel file save with torch.') + raise KeyError(f'Invalid kernel path or name: {repr(name_or_path)} Examples of argument based kernels include: {R.KERNELS.regex_examples}, otherwise specify a valid path to a kernel file save with torch.') -def get_kernel(kernel: Union[str, torch.Tensor], normalize: bool = False): +def get_kernel(kernel: Union[str, torch.Tensor], normalize_mode: str = 'none'): kernel = _get_kernel(kernel) if isinstance(kernel, str) else torch.clone(kernel) - kernel = _normalise_kernel(kernel, normalize=normalize) + kernel = _scale_kernel(kernel, mode=normalize_mode) kernel = _check_kernel(kernel) return kernel +# ========================================================================= # +# Registered Kernels # +# ========================================================================= # + + +# we register this in disent.registry +def 
_make_box_kernel(radius: str): + return torch_box_kernel_2d(radius=int(radius))[None, ...] + + +# we register this in disent.registry +def _make_gaussian_kernel(radius: str): + return torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...] + + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/dataset/transform/_augment_disent.py b/disent/dataset/transform/_augment_disent.py index 5cc305f0..0ed03100 100644 --- a/disent/dataset/transform/_augment_disent.py +++ b/disent/dataset/transform/_augment_disent.py @@ -22,6 +22,10 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import Callable +from typing import Optional +import torch + # ========================================================================= # # Augment # @@ -34,7 +38,11 @@ class DisentDatasetTransform(object): datasets from: disent.dataset.groundtruth """ - def __init__(self, transform=None, transform_targ=None): + def __init__( + self, + transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + transform_targ: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ): self.transform = transform self.transform_targ = transform_targ diff --git a/disent/dataset/transform/functional.py b/disent/dataset/transform/functional.py index e3ebf447..6bc24ffa 100644 --- a/disent/dataset/transform/functional.py +++ b/disent/dataset/transform/functional.py @@ -88,6 +88,12 @@ def check_tensor( # ========================================================================= # +def _is_size_different(obs: Obs, size: SizeType): + h, w = (size, size) if isinstance(size, int) else size + H, W = (obs.height, obs.width) if isinstance(obs, Image) else obs.shape[:2] + return (H != h) or (W != w) + + def to_img_tensor_u8( obs: Obs, size: Optional[SizeType] = None, @@ -101,7 +107,7 @@ def to_img_tensor_u8( 
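Note on the `normalize_mode` values added to `_scale_kernel` in `_augment.py`: the sketch below restates the modes on a flat list of floats so that the differences are easy to see. It is an illustration only, assuming the same definitions as the diff (divide by the sum, the absolute sum, the positive part, the negative part, or the larger of the two parts); `scale_values` is a hypothetical helper and does not use torch.

```python
def scale_values(values, mode='abssum'):
    # magnitude of the positive and negative parts of the kernel
    pos = sum(v for v in values if v > 0)
    neg = sum(-v for v in values if v < 0)
    # choose the divisor according to the mode
    if mode == 'none':
        return list(values)
    elif mode == 'sum':
        div = sum(values)
    elif mode == 'abssum':
        div = pos + neg
    elif mode == 'possum':
        div = pos
    elif mode == 'negsum':
        div = neg
    elif mode == 'maxsum':
        div = max(pos, neg)
    else:
        raise KeyError(f'invalid scale mode: {repr(mode)}')
    return [v / div for v in values]


kernel = [-1.0, 4.0, -1.0]
for mode in ['none', 'sum', 'abssum', 'possum', 'negsum', 'maxsum']:
    print(mode, scale_values(kernel, mode))
```

One thing this makes visible: for a kernel whose entries sum to zero (for example an edge-detection style kernel), the `'sum'` divisor is zero, whereas the `'abssum'`, `'possum'`, `'negsum'` and `'maxsum'` divisors stay positive as long as the corresponding part of the kernel is non-empty.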
3. move channels to first dim (H, W, C) -> (C, H, W) """ # resize image - if size is not None: + if (size is not None) and _is_size_different(obs, size): if not isinstance(obs, Image): obs = F_tv.to_pil_image(obs) obs = F_tv.resize(obs, size=size) @@ -139,11 +145,27 @@ def to_img_tensor_f32( 5. normalize using mean and std, values might thus be outside of the range [0, 1] """ # resize image - if size is not None: + if (size is not None) and _is_size_different(obs, size): if not isinstance(obs, Image): obs = F_tv.to_pil_image(obs) obs = F_tv.resize(obs, size=size) # transform to tensor, add missing dims & move channel dim to front + # TODO: this should be replaced with custom logic, this is quite slow... + # - benchmarks show that doing conversions as numpy first, and then using torch.from_numpy is faster! + # - eg. SmallNorb64Data + # `to_tensor(item) # 27547.81it/s + # `torch.from_numpy(np.transpose(item.astype('float32') / 255, [2, 0, 1])) # 53987.90it/s + # `torch.from_numpy((item.astype('float32') / 255).transpose([2, 0, 1])) # 66544.00it/s + # `torch.from_numpy(item.astype('float32').transpose([2, 0, 1]) / 255) # 66511.62it/s + # `torch.from_numpy(item.transpose([2, 0, 1]).astype('float32') / 255) # 66133.27it/s + # - eg. Cars3d64Data + # `to_tensor(item) # 13810.46it/s + # `torch.from_numpy(np.transpose(item.astype('float32') / 255, [2, 0, 1])) # 32258.03it/s + # `torch.from_numpy((item.astype('float32') / 255).transpose([2, 0, 1])) # 37861.46it/s + # `torch.from_numpy(item.astype('float32').transpose([2, 0, 1]) / 255) # 33034.21it/s + # `torch.from_numpy(item.transpose([2, 0, 1]).astype('float32') / 255) # 32883.32it/s + # - INVESTIGATE: if transpose is used, and then from_numpy is called, that references the original memory? It + # might then be slower to convolve this data? Speed benefits could be negated? A copy might be better? 
obs = F_tv.to_tensor(obs) # checks assert obs.ndim == 3, f'obs has does not have 3 dimensions, got: {obs.ndim} for shape: {obs.shape}' diff --git a/disent/dataset/util/datafile.py b/disent/dataset/util/datafile.py index 7e00b892..3c8430b5 100644 --- a/disent/dataset/util/datafile.py +++ b/disent/dataset/util/datafile.py @@ -27,6 +27,7 @@ from typing import Callable from typing import Dict from typing import final +from typing import NoReturn from typing import Optional from typing import Sequence from typing import Tuple @@ -95,7 +96,7 @@ def wrapped(out_file): self._prepare(out_dir=out_dir, out_file=out_file) return wrapped() - def _prepare(self, out_dir: str, out_file: str) -> str: + def _prepare(self, out_dir: str, out_file: str) -> NoReturn: # TODO: maybe raise a FileNotFoundError or a HashError instead? raise NotImplementedError diff --git a/disent/dataset/util/hdf5.py b/disent/dataset/util/hdf5.py index eec1058b..0d4f17fe 100644 --- a/disent/dataset/util/hdf5.py +++ b/disent/dataset/util/hdf5.py @@ -385,6 +385,7 @@ def add_dataset_from_gt_data( from disent.dataset import DisentDataset from disent.dataset.data import GroundTruthData # get dataset + # TODO: we should not automatically handle this extraction... The transform could be missing on the dataset? 
if isinstance(data, DisentDataset): gt_data = data.gt_data elif isinstance(data, GroundTruthData): gt_data = data else: raise TypeError(f'invalid data type: {type(data)}, must be {DisentDataset} or {GroundTruthData}') diff --git a/disent/dataset/util/npz.py b/disent/dataset/util/npz.py new file mode 100644 index 00000000..c7e83995 --- /dev/null +++ b/disent/dataset/util/npz.py @@ -0,0 +1,77 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +from tqdm import tqdm +from disent.util.inout.files import AtomicSaveFile + + +# ========================================================================= # +# Save Numpy Files # +# ========================================================================= # + + +def save_dataset_array(array: np.ndarray, out_file: str, overwrite: bool = False, save_key: str = 'images'): + assert array.ndim == 4, f'invalid array shape, got: {array.shape}, must be: (N, H, W, C)' + assert array.dtype == 'uint8', f'invalid array dtype, got: {array.dtype}, must be: "uint8"' + # save the data + with AtomicSaveFile(out_file, overwrite=overwrite) as temp_file: + np.savez_compressed(temp_file, **{save_key: array}) + + +def save_resized_dataset_array(array: np.ndarray, out_file: str, size: int = 64, overwrite: bool = False, save_key: str = 'images', progress: bool = True): + import torchvision.transforms.functional as F_tv + from disent.dataset.data import ArrayDataset + # checks + assert out_file.endswith('.npz'), f'The output file must end with the extension: ".npz", got: {repr(out_file)}' + # Get the transform -- copied from: ToImgTensorF32 / ToImgTensorU8 + def transform(obs): + H, W, C = obs.shape + obs = F_tv.to_pil_image(obs) + obs = F_tv.resize(obs, size=[size, size]) + obs = np.array(obs) + # add removed dimension! 
+ if obs.ndim == 2: + obs = obs[:, :, None] + assert obs.shape == (size, size, C) + return obs + # load the converted cars3d data ?x128x128x3 + assert array.ndim == 4, f'invalid array shape, got: {array.shape}, must be: (N, H, W, C)' + assert array.dtype == 'uint8', f'invalid array dtype, got: {array.dtype}, must be: "uint8"' + N, H, W, C = array.shape + data = ArrayDataset(array, transform=transform) + # save the data + with AtomicSaveFile(out_file, overwrite=overwrite) as temp_file: + # resize the cars3d data + idxs = tqdm(range(N), desc = 'converting') if progress else range(N) + converted = np.zeros([N, size, size, C], dtype='uint8') + for i in idxs: + converted[i, ...] = data[i] + # save the data + np.savez_compressed(temp_file, **{save_key: converted}) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/dataset/util/state_space.py b/disent/dataset/util/state_space.py index 5c72168f..e68d0d90 100644 --- a/disent/dataset/util/state_space.py +++ b/disent/dataset/util/state_space.py @@ -54,9 +54,12 @@ class StateSpace(LengthIter): def __init__(self, factor_sizes: Sequence[int], factor_names: Optional[Sequence[str]] = None): super().__init__() - # dimension + # dimension: [read only] self.__factor_sizes = np.array(factor_sizes) self.__factor_sizes.flags.writeable = False + # multipliers: [read only] + self.__factor_multipliers = _dims_multipliers(self.__factor_sizes) + self.__factor_multipliers.flags.writeable = False # total permutations self.__size = int(np.prod(factor_sizes)) # factor names @@ -96,6 +99,21 @@ def factor_names(self) -> Tuple[str, ...]: """A list of names of factors handled by this state space""" return self.__factor_names + @property + def factor_multipliers(self) -> np.ndarray: + """ + The cumulative product of the factor_sizes used to convert indices to positions, and positions to indices. 
+ - The highest value is at the front; the lowest is at the end and is always 1. + - The size of this vector is: num_factors + 1 + + Formulas: + * Use broadcasting to get positions: + pos = (idx[..., None] % muls[:-1]) // muls[1:] + * Use broadcasting to get indices: + idx = np.sum(pos * muls[1:], axis=-1) + """ + return self.__factor_multipliers + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Factor Helpers # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @@ -137,6 +155,8 @@ def pos_to_idx(self, positions) -> np.ndarray: Convert a position to an index (or convert a list of positions to a list of indices) - positions are lists of integers, with each element < their corresponding factor size - indices are integers < size + + TODO: can factor_multipliers be used to speed this up? """ positions = np.moveaxis(positions, source=-1, destination=0) return np.ravel_multi_index(positions, self.__factor_sizes) @@ -146,6 +166,8 @@ def idx_to_pos(self, indices) -> np.ndarray: Convert an index to a position (or convert a list of indices to a list of positions) - indices are integers < size - positions are lists of integers, with each element < their corresponding factor size + + TODO: can factor_multipliers be used to speed this up? """ positions = np.array(np.unravel_index(indices, self.__factor_sizes)) return np.moveaxis(positions, source=0, destination=-1) @@ -249,14 +271,14 @@ def _get_f_idx_and_factors_and_size(self, f_idx: int = None, base_factors=None, # return everything return f_idx, base_factors, num - def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode='interval') -> np.ndarray: + def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode: str = 'interval', start_index: int = 0) -> np.ndarray: """ Sample a single random factor traversal along the given factor index, starting from some random base sample.
""" f_idx, base_factors, num = self._get_f_idx_and_factors_and_size(f_idx=f_idx, base_factors=base_factors, num=num) # generate traversal - base_factors[:, f_idx] = get_idx_traversal(self.factor_sizes[f_idx], num_frames=num, mode=mode) + base_factors[:, f_idx] = get_idx_traversal(self.factor_sizes[f_idx], num_frames=num, mode=mode, start_index=start_index) # return factors (num_frames, num_factors) return base_factors @@ -267,7 +289,7 @@ def sample_random_factor_traversal(self, f_idx: int = None, base_factors=None, n @lru_cache() -def _get_step_size(factor_sizes, f_idx): +def _get_step_size(factor_sizes, f_idx: int): # check values assert f_idx >= 0 assert f_idx < len(factor_sizes) @@ -279,6 +301,20 @@ def _get_step_size(factor_sizes, f_idx): return int(np.ravel_multi_index(pos, factor_sizes)) +def _dims_multipliers(factor_sizes: np.ndarray) -> np.ndarray: + factor_sizes = np.array(factor_sizes) + assert factor_sizes.ndim == 1 + return np.append(np.cumprod(factor_sizes[::-1])[::-1], 1) + + +# @try_njit +# def _idx_to_pos(idxs, dims_mul): +# factors = np.expand_dims(np.array(idxs, dtype='int'), axis=-1) +# factors = factors % dims_mul[:-1] +# factors //= dims_mul[1:] +# return factors + + # ========================================================================= # # Hidden State Space # # ========================================================================= # diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py index 08be5c04..0ba2985f 100644 --- a/disent/dataset/util/stats.py +++ b/disent/dataset/util/stats.py @@ -29,8 +29,6 @@ import torch from torch.utils.data import DataLoader -from disent.util.function import wrapped_partial - # ========================================================================= # # COMPUTE DATASET STATS # @@ -86,15 +84,17 @@ def compute_data_mean_std( if __name__ == '__main__': - def main(progress=False): + def main(progress=True, num_workers=0, batch_size=256): # try changing workers to zero on MacOS from 
disent.dataset import data from disent.dataset.transform import ToImgTensorF32 for data_cls in [ # groundtruth -- impl data.Cars3dData, - data.DSpritesData, + data.Cars3d64Data, data.SmallNorbData, + data.SmallNorb64Data, + data.DSpritesData, data.Shapes3dData, # groundtruth -- impl synthetic data.XYObjectData, @@ -113,7 +113,7 @@ def main(progress=False): # Most common standardized way of computing the mean and std over observations # resized to 64px in size of dtype float32 in the range [0, 1]. data = data_cls(transform=ToImgTensorF32(size=64), **kwargs) - mean, std = compute_data_mean_std(data, progress=progress) + mean, std = compute_data_mean_std(data, progress=progress, num_workers=num_workers, batch_size=batch_size) # results! print(f'{data.__class__.__name__} - {data.name} - {kwargs}:\n mean: {mean.tolist()}\n std: {std.tolist()}') diff --git a/disent/frameworks/_framework.py b/disent/frameworks/_framework.py index 3b51557d..b1d4eeb3 100644 --- a/disent/frameworks/_framework.py +++ b/disent/frameworks/_framework.py @@ -22,6 +22,7 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import logging from dataclasses import asdict from dataclasses import dataclass from dataclasses import fields @@ -32,15 +33,14 @@ from typing import final from typing import Optional from typing import Tuple -from typing import Type from typing import Union -import logging import torch from disent import registry -from disent.schedule import Schedule from disent.nn.modules import DisentLightningModule +from disent.schedule import Schedule +from disent.util.imports import import_obj log = logging.getLogger(__name__) @@ -83,7 +83,7 @@ class DisentFramework(DisentConfigurable, DisentLightningModule): @dataclass class cfg(DisentConfigurable.cfg): # optimizer config - optimizer: Union[str, Type[torch.optim.Optimizer]] = 'adam' + optimizer: Union[str] = 'adam' # name in the registry, eg. `adam` OR the path to an optimizer eg. 
`torch.optim.Adam` optimizer_kwargs: Optional[Dict[str, Union[str, float, int]]] = None def __init__( @@ -94,32 +94,54 @@ def __init__( ): # save the config values to the class super().__init__(cfg=cfg) - # get the optimizer - if isinstance(self.cfg.optimizer, str): - if self.cfg.optimizer not in registry.OPTIMIZERS: - raise KeyError(f'invalid optimizer: {repr(self.cfg.optimizer)}, valid optimizers are: {sorted(registry.OPTIMIZER)}, otherwise pass a torch.optim.Optimizer class instead.') - self.cfg.optimizer = registry.OPTIMIZERS[self.cfg.optimizer] - # check the optimizer values - assert callable(self.cfg.optimizer) - assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(self.cfg.optimizer_kwargs)}' - # set default values for optimizer - if self.cfg.optimizer_kwargs is None: - self.cfg.optimizer_kwargs = dict() - if 'lr' not in self.cfg.optimizer_kwargs: - self.cfg.optimizer_kwargs['lr'] = 1e-3 - log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`') + # check the optimizer + self.cfg.optimizer = self._check_optimizer(self.cfg.optimizer) + self.cfg.optimizer_kwargs = self._check_optimizer_kwargs(self.cfg.optimizer_kwargs) # batch augmentations may not be implemented as dataset # transforms so we can apply these on the GPU instead - assert callable(batch_augment) or (batch_augment is None) + assert callable(batch_augment) or (batch_augment is None), f'invalid batch_augment: {repr(batch_augment)}, must be callable or `None`' self._batch_augment = batch_augment # schedules # - maybe add support for schedules in the config? 
self._registered_schedules = set() self._active_schedules: Dict[str, Tuple[Any, Schedule]] = {} + @staticmethod + def _check_optimizer(optimizer: str): + if not isinstance(optimizer, str): + raise TypeError(f'invalid optimizer: {repr(optimizer)}, must be a `str`') + # check that the optimizer has been registered + # otherwise check that the optimizer class can be imported instead + if optimizer not in registry.OPTIMIZERS: + try: + import_obj(optimizer) + except ImportError: + raise KeyError(f'invalid optimizer: {repr(optimizer)}, valid optimizers are: {sorted(registry.OPTIMIZERS)}, or an import path to an optimizer, eg. `torch.optim.Adam`') + # return the updated values! + return optimizer + + @staticmethod + def _check_optimizer_kwargs(optimizer_kwargs: Optional[dict]): + # check the optimizer kwargs + assert isinstance(optimizer_kwargs, dict) or (optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(optimizer_kwargs)}' + # get default kwargs OR copy + optimizer_kwargs = dict() if (optimizer_kwargs is None) else dict(optimizer_kwargs) + # set default values + if 'lr' not in optimizer_kwargs: + optimizer_kwargs['lr'] = 1e-3 + log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`') + # return the updated values + return optimizer_kwargs + @final def configure_optimizers(self): - optimizer_cls = self.cfg.optimizer + # get the optimizer + # 1. first check if the name has been registered + # 2. 
then check if the name can be imported + if self.cfg.optimizer in registry.OPTIMIZERS: + optimizer_cls = registry.OPTIMIZERS[self.cfg.optimizer] + else: + optimizer_cls = import_obj(self.cfg.optimizer) # check that we can call the optimizer if not callable(optimizer_cls): raise TypeError(f'unsupported optimizer type: {type(optimizer_cls)}') @@ -133,31 +155,24 @@ def configure_optimizers(self): @final def _compute_loss_step(self, batch, batch_idx, update_schedules: bool): - # augment batch with GPU support - if self._batch_augment is not None: - batch = self._batch_augment(batch) - # update the config values based on registered schedules - if update_schedules: - # TODO: how do we handle this in the case of the validation and test step? I think this - # might still give the wrong results as this is based on the trainer.global_step which - # may be incremented by these steps. - self._update_config_from_schedules() - # compute loss - loss, logs_dict = self.do_training_step(batch, batch_idx) - # check returned values - assert 'loss' not in logs_dict - self._assert_valid_loss(loss) - # log returned values - logs_dict['loss'] = loss - self.log_dict(logs_dict) - # return loss - return loss - - @final - def training_step(self, batch, batch_idx): - """This is a pytorch-lightning function that should return the computed loss""" try: - return self._compute_loss_step(batch, batch_idx, update_schedules=True) + # augment batch with GPU support + if self._batch_augment is not None: + batch = self._batch_augment(batch) + # update the config values based on registered schedules + if update_schedules: + # TODO: how do we handle this in the case of the validation and test step? I think this + # might still give the wrong results as this is based on the trainer.global_step which + # may be incremented by these steps. + self._update_config_from_schedules() + # compute loss + # TODO: move logging into child frameworks? 
+ loss = self.do_training_step(batch, batch_idx) + # check loss values + self._assert_valid_loss(loss) + self.log('loss', float(loss), prog_bar=True) + # return loss + return loss except Exception as e: # pragma: no cover # call in all the child processes for the best chance of clearing this... # remove callbacks from trainer so we aren't stuck running forever! @@ -180,11 +195,27 @@ def test_step(self, batch, batch_idx): """ return self._compute_loss_step(batch, batch_idx, update_schedules=False) + @final + def training_step(self, batch, batch_idx): + """This is a pytorch-lightning function that should return the computed loss""" + return self._compute_loss_step(batch, batch_idx, update_schedules=True) + + def validation_step(self, batch, batch_idx): + """ + TODO: how do we handle the schedule in this case? + """ + return self._compute_loss_step(batch, batch_idx, update_schedules=False) + + def test_step(self, batch, batch_idx): + """ + TODO: how do we handle the schedule in this case? + """ + return self._compute_loss_step(batch, batch_idx, update_schedules=False) + @final def _assert_valid_loss(self, loss): - if self.trainer.terminate_on_nan: - if torch.isnan(loss) or torch.isinf(loss): - raise ValueError('The returned loss is nan or inf') + if torch.isnan(loss) or torch.isinf(loss): + raise ValueError('The returned loss is nan or inf') if loss > 1e+20: raise ValueError(f'The returned loss: {loss:.2e} is out of bounds: > {1e+20:.0e}') @@ -192,7 +223,7 @@ def forward(self, batch) -> torch.Tensor: # pragma: no cover """this function should return the single final output of the model, including the final activation""" raise NotImplementedError - def do_training_step(self, batch, batch_idx) -> Tuple[torch.Tensor, Dict[str, Union[Number, torch.Tensor]]]: # pragma: no cover + def do_training_step(self, batch, batch_idx) -> torch.Tensor: # pragma: no cover """ should return a dictionary of items to log with the key 'train_loss' as the variable to minimize diff --git 
a/disent/frameworks/ae/_ae_mixin.py b/disent/frameworks/ae/_ae_mixin.py new file mode 100644 index 00000000..7b1584d0 --- /dev/null +++ b/disent/frameworks/ae/_ae_mixin.py @@ -0,0 +1,147 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +from dataclasses import dataclass +from typing import Dict +from typing import final +from typing import Tuple + +import torch + +from disent.frameworks import DisentFramework +from disent.frameworks.helper.reconstructions import make_reconstruction_loss +from disent.frameworks.helper.reconstructions import ReconLossHandler +from disent.model import AutoEncoder + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# framework_vae # +# ========================================================================= # + + +class _AeAndVaeMixin(DisentFramework): + """ + Base class containing common logic for both the Ae and Vae classes. + -- This is private because handling of both classes needs to be conducted differently. + An instance of a Vae should not be an instance of an Ae and vice-versa + """ + + @dataclass + class cfg(DisentFramework.cfg): + recon_loss: str = 'mse' + # multiple reduction modes exist for the various loss components. 
+ # - 'sum': sum over the entire batch + # - 'mean': mean over the entire batch + # - 'mean_sum': sum each observation, returning the mean sum over the batch + loss_reduction: str = 'mean' + # disable various components + detach_decoder: bool = False + disable_rec_loss: bool = False + disable_aug_loss: bool = False + + # --------------------------------------------------------------------- # + # AE/VAE Attributes # + # --------------------------------------------------------------------- # + + @property + def REQUIRED_Z_MULTIPLIER(self) -> int: + raise NotImplementedError + + @property + def REQUIRED_OBS(self) -> int: + raise NotImplementedError + + @final + @property + def recon_handler(self) -> ReconLossHandler: + return self.__recon_handler + + # --------------------------------------------------------------------- # + # AE/VAE Init # + # --------------------------------------------------------------------- # + + # attributes provided by this class and initialised in _init_ae_mixin + _model: AutoEncoder + __recon_handler: ReconLossHandler + + def _init_ae_mixin(self, model: AutoEncoder): + # vae model + self._model = model + # check the model + assert isinstance(self._model, AutoEncoder), f'model must be an instance of {AutoEncoder.__name__}, got: {type(model)}' + assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}' + # recon loss & activation fn + self.__recon_handler: ReconLossHandler = make_reconstruction_loss(self.cfg.recon_loss, reduction=self.cfg.loss_reduction) + + # --------------------------------------------------------------------- # + # AE/VAE Training Step Helper # + # --------------------------------------------------------------------- # + + @final + def _get_xs_and_targs(self, batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: + 
xs_targ = batch['x_targ'] + if 'x' not in batch: + # TODO: re-enable this warning but only ever print once! + # warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ') + xs = xs_targ + else: + xs = batch['x'] + # check that we have the correct number of inputs + if (len(xs) != self.REQUIRED_OBS) or (len(xs_targ) != self.REQUIRED_OBS): + log.warning(f'batch len(xs)={len(xs)} and len(xs_targ)={len(xs_targ)} observation count mismatch, requires: {self.REQUIRED_OBS}') + # done + return xs, xs_targ + + # --------------------------------------------------------------------- # + # AE/VAE Model Utility Functions (Visualisation) # + # --------------------------------------------------------------------- # + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Get the deterministic latent representation (useful for visualisation)""" + raise NotImplementedError + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """Decode latent vector z into reconstruction x_recon (useful for visualisation)""" + raise NotImplementedError + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Feed through the full deterministic model (useful for visualisation)""" + raise NotImplementedError + + # --------------------------------------------------------------------- # + # AE/VAE Model Utility Functions (Training) # + # --------------------------------------------------------------------- # + + def decode_partial(self, z: torch.Tensor) -> torch.Tensor: + """Decode latent vector z into partial reconstructions that exclude the final activation if there is one.""" + raise NotImplementedError + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index 8e7d2541..f8f4a204 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py 
+++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -23,7 +23,6 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging -import warnings from dataclasses import dataclass from numbers import Number from typing import Any @@ -35,23 +34,18 @@ import torch -from disent.frameworks import DisentFramework -from disent.frameworks.helper.reconstructions import make_reconstruction_loss -from disent.frameworks.helper.reconstructions import ReconLossHandler +from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin from disent.frameworks.helper.util import detach_all from disent.model import AutoEncoder from disent.util.iters import map_all -log = logging.getLogger(__name__) - - # ========================================================================= # # framework_vae # # ========================================================================= # -class Ae(DisentFramework): +class Ae(_AeAndVaeMixin): """ Basic Auto Encoder ------------------ @@ -73,56 +67,23 @@ class Ae(DisentFramework): * `compute_ave_recon_loss` """ + # override REQUIRED_Z_MULTIPLIER = 1 REQUIRED_OBS = 1 @dataclass - class cfg(DisentFramework.cfg): - recon_loss: str = 'mse' - # multiple reduction modes exist for the various loss components. 
- # - 'sum': sum over the entire batch - # - 'mean': mean over the entire batch - # - 'mean_sum': sum each observation, returning the mean sum over the batch - loss_reduction: str = 'mean' - # disable various components - disable_decoder: bool = False - disable_rec_loss: bool = False - disable_aug_loss: bool = False + class cfg(_AeAndVaeMixin.cfg): + pass def __init__(self, model: AutoEncoder, cfg: cfg = None, batch_augment=None): super().__init__(cfg=cfg, batch_augment=batch_augment) - # vae model - self._model = model - # check the model - assert isinstance(self._model, AutoEncoder) - assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}' - # recon loss & activation fn - self.__recon_handler: ReconLossHandler = make_reconstruction_loss(self.cfg.recon_loss, reduction=self.cfg.loss_reduction) - - @final - @property - def recon_handler(self) -> ReconLossHandler: - return self.__recon_handler + # initialise the auto-encoder mixin (recon handler, model, enc, dec, etc.) + self._init_ae_mixin(model=model) # --------------------------------------------------------------------- # # AE Training Step -- Overridable # # --------------------------------------------------------------------- # - @final - def _get_xs_and_targs(self, batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: - xs_targ = batch['x_targ'] - if 'x' not in batch: - # TODO: re-enable this warning but only ever print once! 
- # warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ') - xs = xs_targ - else: - xs = batch['x'] - # check that we have the correct number of inputs - if (len(xs) != self.REQUIRED_OBS) or (len(xs_targ) != self.REQUIRED_OBS): - log.warning(f'batch len(xs)={len(xs)} and len(xs_targ)={len(xs_targ)} observation count mismatch, requires: {self.REQUIRED_OBS}') - # done - return xs, xs_targ - @final def do_training_step(self, batch, batch_idx): xs, xs_targ = self._get_xs_and_targs(batch, batch_idx) @@ -131,10 +92,10 @@ def do_training_step(self, batch, batch_idx): # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # # latent variables zs = map_all(self.encode, xs) - # intercept latent variables + # [HOOK] intercept latent variables zs, logs_intercept_zs = self.hook_ae_intercept_zs(zs) # reconstruct without the final activation - xs_partial_recon = map_all(self.decode_partial, detach_all(zs, if_=self.cfg.disable_decoder)) + xs_partial_recon = map_all(self.decode_partial, detach_all(zs, if_=self.cfg.detach_decoder)) # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # # LOSS @@ -149,14 +110,21 @@ def do_training_step(self, batch, batch_idx): if not self.cfg.disable_aug_loss: loss += aug_loss # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # return values - return loss, { + # log general + self.log_dict({ **logs_intercept_zs, **logs_recon, **logs_aug, - 'recon_loss': recon_loss, - 'aug_loss': aug_loss, - } + }) + + # log progress bar + self.log_dict({ + 'recon_loss': float(recon_loss), + 'aug_loss': float(aug_loss), + }, prog_bar=True) + + # return values + return loss # --------------------------------------------------------------------- # # Overrideable # diff --git a/disent/frameworks/helper/latent_distributions.py b/disent/frameworks/helper/latent_distributions.py index ad7ac0a9..cabf61c4 100644 --- a/disent/frameworks/helper/latent_distributions.py +++ b/disent/frameworks/helper/latent_distributions.py @@ -32,6 +32,7 @@ from torch.distributions 
import Laplace from torch.distributions import Normal +import disent.registry as R from disent.frameworks.helper.util import compute_ave_loss from disent.nn.loss.kl import kl_loss from disent.nn.loss.reduction import loss_reduction @@ -161,17 +162,8 @@ def encoding_to_dists(self, raw_z: Tuple[torch.Tensor, ...]) -> Tuple[Laplace, L # ========================================================================= # -_LATENT_HANDLERS = { - 'normal': LatentDistsHandlerNormal, - 'laplace': LatentDistsHandlerLaplace, -} - - def make_latent_distribution(name: str, kl_mode: str, reduction: str) -> LatentDistsHandler: - try: - cls = _LATENT_HANDLERS[name] - except KeyError: - raise KeyError(f'unknown vae distribution name: {repr(name)}, must be one of: {sorted(_LATENT_HANDLERS.keys())}') + cls = R.LATENT_HANDLERS[name] # make instance return cls(kl_mode=kl_mode, reduction=reduction) diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 723d764f..7da710bb 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -22,23 +22,21 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import re import warnings from typing import final -from typing import List from typing import Sequence -from typing import Tuple from typing import Union import torch import torch.nn.functional as F -from disent import registry +import disent.registry as R +from disent.dataset.transform import FftKernel from disent.frameworks.helper.util import compute_ave_loss from disent.nn.loss.reduction import batch_loss_reduction from disent.nn.loss.reduction import loss_reduction from disent.nn.modules import DisentModule -from disent.dataset.transform import FftKernel +from disent.util.deprecate import deprecated # ========================================================================= # @@ -239,17 +237,30 @@ def compute_unreduced_loss(self, x_recon, x_targ): # ========================================================================= # +_NO_ARG = object() + + class AugmentedReconLossHandler(ReconLossHandler): - def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], wrap_weight=1.0, aug_weight=1.0): + def __init__( + self, + recon_loss_handler: ReconLossHandler, + kernel: Union[str, torch.Tensor], + wrap_weight: float = 1.0, + aug_weight: float = 1.0, + normalize_mode: str = _NO_ARG + ): super().__init__(reduction=recon_loss_handler._reduction) # save variables self._recon_loss_handler = recon_loss_handler # must be a recon loss handler, but cannot nest augmented handlers assert isinstance(recon_loss_handler, ReconLossHandler) assert not isinstance(recon_loss_handler, AugmentedReconLossHandler) + # deprecation error + if normalize_mode is _NO_ARG: + raise ValueError(f'default argument for normalize_mode was "sum", this has been deprecated and will change to "none" in future. 
Please manually override this value!') # load the kernel - self._kernel = FftKernel(kernel=kernel, normalize=True) + self._kernel = FftKernel(kernel=kernel, normalize_mode=normalize_mode) # kernel weighting assert 0 <= wrap_weight, f'loss_weight must be in the range [0, inf) but received: {repr(wrap_weight)}' assert 0 <= aug_weight, f'kern_weight must be in the range [0, inf) but received: {repr(aug_weight)}' @@ -273,30 +284,31 @@ def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_t # ========================================================================= # # Registry & Factory # +# TODO: add ability to register parameterized reconstruction losses # ========================================================================= # -# TODO: add ability to register parameterized reconstruction losses -_ARG_RECON_LOSSES: List[Tuple[re.Pattern, str, callable]] = [ - # (REGEX, EXAMPLE, FACTORY_FUNC) - # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group - # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first. 
-] +def _make_aug_recon_loss_l_w_n(loss: str, kern: str, loss_weight: str, kernel_weight: str, normalize_mode: str): + def _loss(reduction: str): + return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=float(loss_weight), aug_weight=float(kernel_weight), normalize_mode=normalize_mode) + return _loss + + +def _make_aug_recon_loss_l1_w1_n(loss: str, kern: str, normalize_mode: str): + def _loss(reduction: str): + return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=1.0, aug_weight=1.0, normalize_mode=normalize_mode) + return _loss + + +def _make_aug_recon_loss_l1_w1_nnone(loss: str, kern: str): + def _loss(reduction: str): + return AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=1.0, aug_weight=1.0, normalize_mode='none') + return _loss # NOTE: this function compliments make_kernel in transform/_augment.py def make_reconstruction_loss(name: str, reduction: str) -> ReconLossHandler: - if name in registry.RECON_LOSSES: - # search normal losses! - return registry.RECON_LOSSES[name](reduction) - else: - # regex search losses, and call with args! 
- for r, _, fn in _ARG_RECON_LOSSES: - result = r.search(name) - if result is not None: - return fn(reduction, *result.groups()) - # we couldn't find anything - raise KeyError(f'Invalid vae reconstruction loss: {repr(name)} Valid losses include: {sorted(registry.RECON_LOSSES)}, examples of additional argument based losses include: {[example for _, example, _ in _ARG_RECON_LOSSES]}') + return R.RECON_LOSSES[name](reduction=reduction) # ========================================================================= # diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py index 366bd4e9..1ad85b33 100644 --- a/disent/frameworks/vae/_unsupervised__vae.py +++ b/disent/frameworks/vae/_unsupervised__vae.py @@ -27,21 +27,17 @@ from typing import Any from typing import Dict from typing import final -from typing import Optional from typing import Sequence from typing import Tuple from typing import Union import torch from torch.distributions import Distribution -from torch.distributions import Laplace -from torch.distributions import Normal -from disent.frameworks.ae._unsupervised__ae import Ae +from disent.frameworks.ae._ae_mixin import _AeAndVaeMixin from disent.frameworks.helper.latent_distributions import LatentDistsHandler from disent.frameworks.helper.latent_distributions import make_latent_distribution from disent.frameworks.helper.util import detach_all - from disent.util.iters import map_all @@ -50,7 +46,7 @@ # ========================================================================= # -class Vae(Ae): +class Vae(_AeAndVaeMixin): """ Variational Auto Encoder https://arxiv.org/abs/1312.6114 @@ -88,22 +84,23 @@ class Vae(Ae): Vae(hook_compute_ave_aug_loss=TripletLoss(), required_obs=3) """ - # override required z from AE + # overrides REQUIRED_Z_MULTIPLIER = 2 REQUIRED_OBS = 1 @dataclass - class cfg(Ae.cfg): + class cfg(_AeAndVaeMixin.cfg): # latent distribution settings latent_distribution: str = 'normal' kl_loss_mode: str = 
'direct' # disable various components disable_reg_loss: bool = False - disable_posterior_scale: Optional[float] = None def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): # required_z_multiplier - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) + super().__init__(cfg=cfg, batch_augment=batch_augment) + # initialise the auto-encoder mixin (recon handler, model, enc, dec, etc.) + self._init_ae_mixin(model=model) # vae distribution self.__latents_handler = make_latent_distribution(self.cfg.latent_distribution, kl_mode=self.cfg.kl_loss_mode, reduction=self.cfg.loss_reduction) @@ -124,14 +121,12 @@ def do_training_step(self, batch, batch_idx): # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # # latent distribution parameterizations ds_posterior, ds_prior = map_all(self.encode_dists, xs, collect_returned=True) - # [HOOK] disable learnt scale values - ds_posterior, ds_prior = self._hook_intercept_ds_disable_scale(ds_posterior, ds_prior) # [HOOK] intercept latent parameterizations ds_posterior, ds_prior, logs_intercept_ds = self.hook_intercept_ds(ds_posterior, ds_prior) # sample from dists zs_sampled = tuple(d.rsample() for d in ds_posterior) # reconstruct without the final activation - xs_partial_recon = map_all(self.decode_partial, detach_all(zs_sampled, if_=self.cfg.disable_decoder)) + xs_partial_recon = map_all(self.decode_partial, detach_all(zs_sampled, if_=self.cfg.detach_decoder)) # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # # LOSS @@ -149,45 +144,23 @@ def do_training_step(self, batch, batch_idx): if not self.cfg.disable_reg_loss: loss += reg_loss # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # return values - return loss, { + # log general + self.log_dict({ **logs_intercept_ds, **logs_recon, **logs_reg, **logs_aug, - 'recon_loss': recon_loss, - 'reg_loss': reg_loss, - 'aug_loss': aug_loss, - # ratios - 'ratio_reg': (reg_loss / loss) if (loss != 0) else 0, - 'ratio_rec': (recon_loss / loss) if (loss != 0) else 0, - 'ratio_aug': 
(aug_loss / loss) if (loss != 0) else 0, - } + }) - # --------------------------------------------------------------------- # - # Delete AE Hooks # - # --------------------------------------------------------------------- # - - @final - def hook_ae_intercept_zs(self, zs: Sequence[torch.Tensor]) -> Tuple[Sequence[torch.Tensor], Dict[str, Any]]: - raise RuntimeError('This function should never be used or overridden by VAE methods!') # pragma: no cover + # log progress bar + self.log_dict({ + 'recon_loss': float(recon_loss), + 'reg_loss': float(reg_loss), + 'aug_loss': float(aug_loss), + }, prog_bar=True) - @final - def hook_ae_compute_ave_aug_loss(self, zs: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]: - raise RuntimeError('This function should never be used or overridden by VAE methods!') # pragma: no cover - - # --------------------------------------------------------------------- # - # Private Hooks # - # --------------------------------------------------------------------- # - - def _hook_intercept_ds_disable_scale(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]): - # disable posterior scales - if self.cfg.disable_posterior_scale is not None: - for d_posterior in ds_posterior: - assert isinstance(d_posterior, (Normal, Laplace)) - d_posterior.scale = torch.full_like(d_posterior.scale, fill_value=self.cfg.disable_posterior_scale) - # return modified values - return ds_posterior, ds_prior + # return values + return loss # --------------------------------------------------------------------- # # Overrideable Hooks # @@ -199,6 +172,14 @@ def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequ def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) 
-> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]: return 0, {} + def compute_ave_recon_loss(self, xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]: + # compute reconstruction loss + pixel_loss = self.recon_handler.compute_ave_loss_from_partial(xs_partial_recon, xs_targ) + # return logs + return pixel_loss, { + 'pixel_loss': pixel_loss + } + def compute_ave_reg_loss(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution], zs_sampled: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]: # compute regularization loss (kl divergence) kl_loss = self.latents_handler.compute_ave_kl_loss(ds_posterior, ds_prior, zs_sampled) @@ -208,7 +189,7 @@ def compute_ave_reg_loss(self, ds_posterior: Sequence[Distribution], ds_prior: S } # --------------------------------------------------------------------- # - # VAE - Encoding - Overrides AE # + # VAE Model Utility Functions (Visualisation) # # --------------------------------------------------------------------- # @final @@ -218,6 +199,20 @@ def encode(self, x: torch.Tensor) -> torch.Tensor: z = self.latents_handler.encoding_to_representation(z_raw) return z + @final + def decode(self, z: torch.Tensor) -> torch.Tensor: + """Decode latent vector z into reconstruction x_recon (useful for visualisation)""" + return self.recon_handler.activate(self._model.decode(z)) + + @final + def forward(self, batch: torch.Tensor) -> torch.Tensor: + """Feed through the full deterministic model (useful for visualisation)""" + return self.decode(self.encode(batch)) + + # --------------------------------------------------------------------- # + # VAE Model Utility Functions (Training) # + # --------------------------------------------------------------------- # + @final def encode_dists(self, x: torch.Tensor) -> Tuple[Distribution, Distribution]: """Get parametrisations of the latent distributions, which are sampled from 
during training.""" @@ -225,6 +220,11 @@ def encode_dists(self, x: torch.Tensor) -> Tuple[Distribution, Distribution]: z_posterior, z_prior = self.latents_handler.encoding_to_dists(z_raw) return z_posterior, z_prior + @final + def decode_partial(self, z: torch.Tensor) -> torch.Tensor: + """Decode latent vector z into partial reconstructions that exclude the final activation if there is one.""" + return self._model.decode(z) + # ========================================================================= # # END # diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py index 98099e76..5cc8994a 100644 --- a/disent/metrics/__init__.py +++ b/disent/metrics/__init__.py @@ -33,24 +33,3 @@ # ========================================================================= # # Fast Metric Settings # # ========================================================================= # - - -# helper imports -from disent.util.function import wrapped_partial as _wrapped_partial - - -FAST_METRICS = { - 'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'), - 'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000), # may not be accurate, but it just takes waay too long otherwise 20+ seconds - 'mig': _wrapped_partial(metric_mig, num_train=2000), - 'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000), - 'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000), -} - -DEFAULT_METRICS = { - 'dci': metric_dci, - 'factor_vae': metric_factor_vae, - 'mig': metric_mig, - 'sap': metric_sap, - 'unsupervised': metric_unsupervised, -} diff --git a/disent/metrics/_dci.py b/disent/metrics/_dci.py index 2c48bcd5..7138bb17 100644 --- a/disent/metrics/_dci.py +++ b/disent/metrics/_dci.py @@ -35,6 +35,8 @@ import scipy import scipy.stats +from disent.metrics.utils import make_metric + log = logging.getLogger(__name__) @@ -44,6 +46,7 @@ # 
========================================================================= # +@make_metric('dci', fast_kwargs=dict(num_train=1000, num_test=500)) def metric_dci( dataset: DisentDataset, representation_function: callable, @@ -87,10 +90,10 @@ def _compute_dci(mus_train, ys_train, mus_test, ys_test, boost_mode='sklearn', s assert importance_matrix.shape[0] == mus_train.shape[0] assert importance_matrix.shape[1] == ys_train.shape[0] return { - "dci.informativeness_train": train_err, - "dci.informativeness_test": test_err, - "dci.disentanglement": _disentanglement(importance_matrix), - "dci.completeness": _completeness(importance_matrix), + "dci.informativeness_train": train_err, # "dci.explicitness" -- Measuring Disentanglement: A Review of Metrics + "dci.informativeness_test": test_err, # "dci.explicitness" -- Measuring Disentanglement: A Review of Metrics + "dci.disentanglement": _disentanglement(importance_matrix), # "dci.modularity" -- Measuring Disentanglement: A Review of Metrics + "dci.completeness": _completeness(importance_matrix), # "dci.compactness" -- Measuring Disentanglement: A Review of Metrics } diff --git a/disent/metrics/_factor_vae.py b/disent/metrics/_factor_vae.py index 66ffc67a..5576a4f2 100644 --- a/disent/metrics/_factor_vae.py +++ b/disent/metrics/_factor_vae.py @@ -32,6 +32,7 @@ from disent.dataset import DisentDataset from disent.metrics import utils +from disent.metrics.utils import make_metric from disent.util import to_numpy @@ -43,6 +44,7 @@ # ========================================================================= # +@make_metric('factor_vae', fast_kwargs=dict(num_train=700, num_eval=350, num_variance_estimate=1000)) # may not be accurate, but it just takes waay too long otherwise 20+ seconds def metric_factor_vae( dataset: DisentDataset, representation_function: callable, @@ -124,8 +126,8 @@ def metric_factor_vae( eval_accuracy = np.sum(eval_votes[classifier, other_index]) * 1. 
/ np.sum(eval_votes) return { - "factor_vae.train_accuracy": train_accuracy, - "factor_vae.eval_accuracy": eval_accuracy, + "factor_vae.train_accuracy": train_accuracy, # "z-min variance" -- Measuring Disentanglement: A Review of Metrics + "factor_vae.eval_accuracy": eval_accuracy, # "z-min variance" -- Measuring Disentanglement: A Review of Metrics "factor_vae.num_active_dims": len(active_dims), } diff --git a/disent/metrics/_mig.py b/disent/metrics/_mig.py index 3e0b1a87..4ccd0550 100644 --- a/disent/metrics/_mig.py +++ b/disent/metrics/_mig.py @@ -32,6 +32,7 @@ from disent.dataset import DisentDataset from disent.metrics import utils +from disent.metrics.utils import make_metric log = logging.getLogger(__name__) @@ -42,6 +43,7 @@ # ========================================================================= # +@make_metric('mig', fast_kwargs=dict(num_train=2000)) def metric_mig( dataset: DisentDataset, representation_function, @@ -76,5 +78,5 @@ def _compute_mig(mus_train, ys_train): entropy = utils.discrete_entropy(ys_train) sorted_m = np.sort(m, axis=0)[::-1] return { - "mig.discrete_score": np.mean(np.divide(sorted_m[0, :] - sorted_m[1, :], entropy[:])) + "mig.discrete_score": np.mean(np.divide(sorted_m[0, :] - sorted_m[1, :], entropy[:])) # "modularity: MIG" -- Measuring Disentanglement: A Review of Metrics } diff --git a/disent/metrics/_sap.py b/disent/metrics/_sap.py index 36fad48d..09722d06 100644 --- a/disent/metrics/_sap.py +++ b/disent/metrics/_sap.py @@ -33,6 +33,7 @@ from disent.dataset import DisentDataset from disent.metrics import utils +from disent.metrics.utils import make_metric log = logging.getLogger(__name__) @@ -43,6 +44,7 @@ # ========================================================================= # +@make_metric('sap', fast_kwargs=dict(num_train=2000, num_test=1000)) def metric_sap( dataset: DisentDataset, representation_function, @@ -80,7 +82,7 @@ def _compute_sap(mus, ys, mus_test, ys_test, continuous_factors): sap_score = 
_compute_avg_diff_top_two(score_matrix) log.debug("SAP score: %.2g", sap_score) return { - "sap.score": sap_score + "sap.score": sap_score # "compactness: SAP" -- Measuring Disentanglement: A Review of Metrics } diff --git a/disent/metrics/_unsupervised.py b/disent/metrics/_unsupervised.py index c872f110..e351e2ef 100644 --- a/disent/metrics/_unsupervised.py +++ b/disent/metrics/_unsupervised.py @@ -31,6 +31,8 @@ from disent.dataset import DisentDataset from disent.metrics import utils +from disent.metrics.utils import make_metric + log = logging.getLogger(__name__) @@ -40,6 +42,7 @@ # ========================================================================= # +@make_metric('unsupervised', fast_kwargs=dict(num_train=2000)) def metric_unsupervised( dataset: DisentDataset, representation_function, diff --git a/disent/metrics/utils.py b/disent/metrics/utils.py index e8c6b2fe..9bd28569 100644 --- a/disent/metrics/utils.py +++ b/disent/metrics/utils.py @@ -23,16 +23,94 @@ Utility functions that are useful for the different metrics. 
""" +from numbers import Number +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Optional +from typing import Protocol +from typing import TypeVar +from typing import Union + import numpy as np import sklearn from tqdm import tqdm from disent.dataset import DisentDataset from disent.util import to_numpy +from disent.util.function import wrapped_partial + + +# ========================================================================= # +# Metric Wrapper # +# ========================================================================= # + + +T = TypeVar('T') + + +class Metric(Generic[T]): + + def __init__( + self, + name: str, + metric_fn: T, # Callable[[...], Dict[str, Number]] + default_kwargs: Optional[Dict[str, Any]] = None, + fast_kwargs: Optional[Dict[str, Any]] = None, + ): + self._name = name + self._orig_fn = metric_fn + self._metric_fn_default = wrapped_partial(self._orig_fn, **(default_kwargs if default_kwargs else {})) + self._metric_fn_fast = wrapped_partial(self._orig_fn, **(fast_kwargs if fast_kwargs else {})) + + # How do we get a type hint for `__call__` so that its signature matches `T`? + def __call__(self, *args, **kwargs) -> Dict[str, Number]: + return self._metric_fn_default(*args, **kwargs) + + @property + def compute(self) -> T: + return self._metric_fn_default + + @property + def compute_fast(self) -> T: + return self._metric_fn_fast + + @property + def unwrap(self) -> T: + return self._orig_fn + + @property + def name(self) -> str: + return self._name + + def __str__(self): + return f'metric-{self.name}' + + +def make_metric( + name: str, + default_kwargs: Optional[Dict[str, Any]] = None, + fast_kwargs: Optional[Dict[str, Any]] = None, +) -> Callable[[T], Union[Metric[T], T]]: + """ + Metrics should be decorated using this function to set defaults! + Two versions of the metric should exist. + 1. 
Recommended settings + - This should give reliable results, but may be very slow, multiple minutes to half an + hour or more for some metrics depending on the underlying model, data and ground-truth factors. + 2. Faster settings + - This should give decent results, but should be reasonably fast, a few seconds/minutes at most. + This is not used for testing + """ + # `Union[Metric[T], T]` is a hack to get type hint on `__call__` + def _wrap_fn_as_metric(metric_fn: T) -> Union[Metric[T], T]: + return Metric(name=name, metric_fn=metric_fn, default_kwargs=default_kwargs, fast_kwargs=fast_kwargs) + return _wrap_fn_as_metric # ========================================================================= # -# utils # +# utils # # ========================================================================= # diff --git a/disent/nn/functional/__init__.py b/disent/nn/functional/__init__.py index 4ac90f16..8af43526 100644 --- a/disent/nn/functional/__init__.py +++ b/disent/nn/functional/__init__.py @@ -47,6 +47,12 @@ from disent.nn.functional._mean import torch_mean_geometric from disent.nn.functional._mean import torch_mean_harmonic +from disent.nn.functional._norm import torch_norm +from disent.nn.functional._norm import torch_dist +from disent.nn.functional._norm import torch_norm_euclidean +from disent.nn.functional._norm import torch_norm_manhattan +from disent.nn.functional._norm import torch_dist_hamming + from disent.nn.functional._other import torch_normalize from disent.nn.functional._other import torch_nan_to_num from disent.nn.functional._other import torch_unsqueeze_l diff --git a/disent/nn/functional/_mean.py b/disent/nn/functional/_mean.py index 00bdf0a0..6493994e 100644 --- a/disent/nn/functional/_mean.py +++ b/disent/nn/functional/_mean.py @@ -41,12 +41,14 @@ _NEG_INF = float('-inf') _GENERALIZED_MEAN_MAP = { + 'inf': _POS_INF, 'maximum': _POS_INF, 'quadratic': 2, 'arithmetic': 1, 'geometric': 0, 'harmonic': -1, 'minimum': _NEG_INF, + '-inf': _NEG_INF, } @@ -55,7
+57,7 @@ # ========================================================================= # -def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1, keepdim: bool = False): +def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[float, str] = 1, keepdim: bool = False): """ Compute the generalised mean. - p is the power @@ -69,9 +71,9 @@ def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[ p = _GENERALIZED_MEAN_MAP[p] # compute the specific extreme cases if p == _POS_INF: - return torch.max(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.max(xs, keepdim=keepdim) + return torch.amax(xs, dim=dim, keepdim=keepdim) elif p == _NEG_INF: - return torch.min(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.min(xs, keepdim=keepdim) + return torch.amin(xs, dim=dim, keepdim=keepdim) # compute the number of elements being averaged if dim is None: dim = list(range(xs.ndim)) diff --git a/disent/nn/functional/_norm.py b/disent/nn/functional/_norm.py new file mode 100644 index 00000000..6c972482 --- /dev/null +++ b/disent/nn/functional/_norm.py @@ -0,0 +1,129 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import warnings +from typing import List +from typing import Optional +from typing import Union + +import torch + + +# ========================================================================= # +# Types # +# ========================================================================= # + + +_DimTypeHint = Optional[Union[int, List[int]]] + +_POS_INF = float('inf') +_NEG_INF = float('-inf') + +_P_NORM_MAP = { + 'inf': _POS_INF, + 'maximum': _POS_INF, + 'euclidean': 2, + 'manhattan': 1, + # p < 1 is not a valid norm! `torch_norm` rejects it, but `torch_dist` allows it + 'hamming': 0, + 'minimum': _NEG_INF, + '-inf': _NEG_INF, +} + + +# ========================================================================= # +# p-Norm Functions # +# -- very similar to the generalised mean # +# ========================================================================= # + + +def torch_dist(xs: torch.Tensor, dim: _DimTypeHint = -1, p: Union[float, str] = 1, keepdim: bool = False): + """ + Like torch_norm, but allows arbitrary p values that may + result in the returned values not being a valid norm. + - norms require p >= 1 + """ + if isinstance(p, str): + p = _P_NORM_MAP[p] + # get absolute values + xs = torch.abs(xs) + # compute the specific extreme cases + # -- it's kind of odd that the p-norm and generalised mean converge to the + # same values, just from different directions!
+ if p == _POS_INF: + return torch.amax(xs, dim=dim, keepdim=keepdim) + elif p == _NEG_INF: + return torch.amin(xs, dim=dim, keepdim=keepdim) + # get the dimensions + if dim is None: + dim = list(range(xs.ndim)) + # warn if the type is wrong + if p != 1: + if xs.dtype != torch.float64: + warnings.warn(f'Input tensor to p-norm might not have the required precision, type is {xs.dtype} not {torch.float64}.') + # compute the specific cases + if p == 0: + # hamming distance -- number of non-zero entries of the vector + return torch.sum(xs != 0, dim=dim, keepdim=keepdim, dtype=xs.dtype) + if p == 1: + # manhattan distance + return torch.sum(xs, dim=dim, keepdim=keepdim) + else: + # p-norm (if p==2, then euclidean distance) + return torch.sum(xs ** p, dim=dim, keepdim=keepdim) ** (1/p) + + +def torch_norm(xs: torch.Tensor, dim: _DimTypeHint = -1, p: Union[float, str] = 1, keepdim: bool = False): + """ + Compute the generalised p-norm over the given dimension of a vector! + - p values must be >= 1 + + Closely related to the generalised mean. 
+ - p-norm: (sum(|x|^p)) ^ (1/p) + - gen-mean: (1/n * sum(x ^ p)) ^ (1/p) + """ + if isinstance(p, str): + p = _P_NORM_MAP[p] + # check values + if p < 1: + raise ValueError(f'p-norm cannot have a p value less than 1, got: {repr(p)}, to allow p < 1 use `torch_dist` instead.') + # return norm + return torch_dist(xs=xs, dim=dim, p=p, keepdim=keepdim) + + +def torch_norm_euclidean(xs, dim: _DimTypeHint = -1, keepdim: bool = False): + return torch_dist(xs, dim=dim, p='euclidean', keepdim=keepdim) + + +def torch_norm_manhattan(xs, dim: _DimTypeHint = -1, keepdim: bool = False): + return torch_dist(xs, dim=dim, p='manhattan', keepdim=keepdim) + + +def torch_dist_hamming(xs, dim: _DimTypeHint = -1, keepdim: bool = False): + return torch_dist(xs, dim=dim, p='hamming', keepdim=keepdim) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_pca.py b/disent/nn/functional/_pca.py index 33212f41..4cb40d76 100644 --- a/disent/nn/functional/_pca.py +++ b/disent/nn/functional/_pca.py @@ -30,14 +30,16 @@ # ========================================================================= # -def torch_pca_eig(X, center=True, scale=False): +def torch_pca_eig(X, center=True, scale=False, zero_negatives=False): """ perform PCA over X - X is of size (num_points, vec_size) NOTE: unlike PCA_svd, the number of vectors/values returned is always: vec_size + + WARNING: this may be incorrect!
""" - n, _ = X.shape + n, m = X.shape # center points along axes if center: X = X - X.mean(dim=0) @@ -47,10 +49,14 @@ def torch_pca_eig(X, center=True, scale=False): scaling = torch.sqrt(1 / torch.diagonal(covariance)) covariance = torch.mm(torch.diagflat(scaling), covariance) # compute eigen values and eigen vectors - eigenvalues, eigenvectors = torch.eig(covariance, True) + eigenvalues, eigenvectors = torch.linalg.eig(covariance) # sort components by decreasing variance - components = eigenvectors.T - explained_variance = eigenvalues[:, 0] + components = torch.real(eigenvectors.T) # TODO: handle imaginary numbers! + explained_variance = torch.real(eigenvalues) # TODO: handle imaginary numbers! + # handle n < m -- numerical stability issues return negative values! + # maybe this should just zero out the negatives instead, they don't contribute!? + explained_variance = torch.abs(explained_variance) + # return sorted idxs = torch.argsort(explained_variance, descending=True) return components[idxs], explained_variance[idxs] @@ -75,7 +81,10 @@ def torch_pca_svd(X, center=True): return components, explained_variance -def torch_pca(X, center=True, mode='svd'): +def torch_pca(X, center=True, mode='svd') -> (torch.Tensor, torch.Tensor): + # number of values returned may differ depending on the method! + # -- svd returns: min(num, z_size) + # -- eig returns: num if mode == 'svd': return torch_pca_svd(X, center=center) elif mode == 'eig': diff --git a/disent/nn/loss/reduction.py b/disent/nn/loss/reduction.py index a1f08efc..ea021b33 100644 --- a/disent/nn/loss/reduction.py +++ b/disent/nn/loss/reduction.py @@ -78,6 +78,7 @@ def get_mean_loss_scale(x: torch.Tensor, reduction: str): } +# TODO: this is duplicated in research ... 
pairwise_loss # applying this function and then taking the # mean should give the same result as loss_reduction def batch_loss_reduction(tensor: torch.Tensor, reduction_dtype=None, reduction='mean') -> torch.Tensor: diff --git a/disent/nn/loss/softsort.py b/disent/nn/loss/softsort.py index 2f50d16e..19158ae4 100644 --- a/disent/nn/loss/softsort.py +++ b/disent/nn/loss/softsort.py @@ -104,7 +104,7 @@ def torch_soft_sort( dims: Union[int, Tuple[int, ...]] = -1, regularization='l2', regularization_strength=1.0, - dims_at_end=False, + leave_dims_at_end=False, ): # we import it locally so that we don't have to install this import torchsort @@ -113,7 +113,7 @@ def torch_soft_sort( # sort the last dimension of the 2D tensors tensor = torchsort.soft_sort(tensor, regularization=regularization, regularization_strength=regularization_strength) # undo the reorder operation - if dims_at_end: + if leave_dims_at_end: return tensor return torch_undo_dims_at_end_2d(tensor, moved_shape=moved_shape, moved_end_dims=moved_end_dims) diff --git a/disent/nn/loss/triplet.py b/disent/nn/loss/triplet.py index f1e55f9d..4ae29afd 100644 --- a/disent/nn/loss/triplet.py +++ b/disent/nn/loss/triplet.py @@ -70,7 +70,7 @@ def dist_triplet_sigmoid_loss(pos_delta, neg_delta, margin_min=None, margin_max= https://arxiv.org/pdf/2003.14021.pdf """ if margin_min is not None: - warnings.warn('triplet_loss does not support margin_min') + warnings.warn('triplet_sigmoid_loss does not support margin_min') p_dist = torch.norm(pos_delta, p=p, dim=-1) n_dist = torch.norm(neg_delta, p=p, dim=-1) loss = torch.sigmoid((1/margin_max) * (p_dist - n_dist)) @@ -80,6 +80,32 @@ def dist_triplet_sigmoid_loss(pos_delta, neg_delta, margin_min=None, margin_max= # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # +def triplet_soft_loss(anc, pos, neg, margin_min=None, margin_max=None, p=1): + """ + Triplet Loss With Soft-Margin + https://arxiv.org/pdf/1703.07737.pdf + """ + return dist_triplet_soft_loss(anc - pos, anc - neg, 
margin_min=margin_min, margin_max=margin_max, p=p) + + +def dist_triplet_soft_loss(pos_delta, neg_delta, margin_min=None, margin_max=None, p=1): + """ + Triplet Loss With Soft-Margin + https://arxiv.org/pdf/1703.07737.pdf + """ + if margin_min is not None: + warnings.warn('triplet_soft_loss does not support margin_min') + if margin_max is not None: + warnings.warn('triplet_soft_loss does not support margin_max') + p_dist = torch.norm(pos_delta, p=p, dim=-1) + n_dist = torch.norm(neg_delta, p=p, dim=-1) + loss = torch.log(1 + torch.exp(p_dist - n_dist)) + return loss.mean() + + +# -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # + + # def elem_triplet_loss(anc, pos, neg, margin_min=None, margin_max=1., p=1): # """ # Element-Wise Triplet Loss @@ -132,6 +158,7 @@ def min_clamped_triplet_loss(anc, pos, neg, margin_min=0.01, margin_max=1., p=1) """ Min Margin Triplet Loss TODO: is this better, or clamped_triplet_loss? + TODO: could take idea from soft-margin to make this continuously differentiable? """ return dist_min_clamped_triplet_loss(anc - pos, anc - neg, margin_min=margin_min, margin_max=margin_max, p=p) @@ -140,6 +167,7 @@ def dist_min_clamped_triplet_loss(pos_delta, neg_delta, margin_min=0.01, margin_ """ Min Margin Triplet Loss TODO: is this better, or dist_clamped_triplet_loss? + TODO: could take idea from soft-margin to make this continuously differentiable? """ p_dist = torch.norm(pos_delta, p=p, dim=-1) n_dist = torch.norm(neg_delta, p=p, dim=-1) @@ -153,6 +181,7 @@ def split_clamped_triplet_loss(anc, pos, neg, margin_min=0.01, margin_max=1., p= """ Min Margin Triplet Loss TODO: is this better, or min_clamp_triplet_loss? + TODO: could take idea from soft-margin to make this continuously differentiable? 
""" return dist_split_clamped_triplet_loss(anc - pos, anc - neg, margin_min=margin_min, margin_max=margin_max, p=p) @@ -161,6 +190,7 @@ def dist_split_clamped_triplet_loss(pos_delta, neg_delta, margin_min=0.01, margi """ Min Margin Triplet Loss TODO: is this better, or dist_min_clamp_triplet_loss? + TODO: could take idea from soft-margin to make this continuously differentiable? """ p_dist = torch.norm(pos_delta, p=p, dim=-1) n_dist = torch.norm(neg_delta, p=p, dim=-1) @@ -212,12 +242,13 @@ class TripletLossConfig(object): triplet_margin_min: float = 0.1 triplet_margin_max: float = 10 triplet_scale: float = 100 - triplet_p: int = 2 + triplet_p: float = 2 _TRIPLET_LOSSES = { 'triplet': triplet_loss, 'triplet_sigmoid': triplet_sigmoid_loss, + 'triplet_soft': triplet_soft_loss, # 'elem_triplet': elem_triplet_loss, # 'min_margin_triplet': min_margin_triplet_loss, 'min_clamped_triplet': min_clamped_triplet_loss, @@ -229,6 +260,7 @@ class TripletLossConfig(object): _DIST_TRIPLET_LOSSES = { 'triplet': dist_triplet_loss, 'triplet_sigmoid': dist_triplet_sigmoid_loss, + 'triplet_soft': dist_triplet_soft_loss, # 'elem_triplet': dist_elem_triplet_loss, # 'min_margin_triplet': dist_min_margin_triplet_loss, 'min_clamped_triplet': dist_min_clamped_triplet_loss, @@ -281,11 +313,3 @@ def compute_dist_triplet_loss(zs_deltas: Sequence[torch.Tensor], cfg: TripletCon # ========================================================================= # # END # # ========================================================================= # - - - - - - - - diff --git a/disent/nn/loss/triplet_mining.py b/disent/nn/loss/triplet_mining.py index 84f71ce9..0eaa0ac6 100644 --- a/disent/nn/loss/triplet_mining.py +++ b/disent/nn/loss/triplet_mining.py @@ -157,8 +157,13 @@ def configured_idx_mine( cfg: SampledTripletMineCfgProto, pairwise_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], # should return arrays with ndim == 1 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # TODO: 
SIMPLIFY THIS FUNCTION HIERARCHY, THERE ARE A LOT OF UNNECESSARY CALLS! # TODO: this function is quite useless, its easier to just use configured_mine_random_mode + # skip mining if mode is None! + if cfg.overlap_mine_triplet_mode == 'none': + return a_idxs, p_idxs, n_idxs # compute differences + # TODO: this is computationally expensive! sometimes the dist_ap and dist_an may not be used depending on the mode! dist_ap = pairwise_loss_fn(x_targ[a_idxs], x_targ[p_idxs]) dist_an = pairwise_loss_fn(x_targ[a_idxs], x_targ[n_idxs]) # mine indices diff --git a/disent/nn/modules.py b/disent/nn/modules.py index 9de5f574..fdeac6cb 100644 --- a/disent/nn/modules.py +++ b/disent/nn/modules.py @@ -43,11 +43,12 @@ def forward(self, *args, **kwargs): class DisentLightningModule(pl.LightningModule): - - def _forward_unimplemented(self, *args): - # Annoying fix applied by torch for Module.forward: - # https://github.com/python/mypy/issues/8795 - raise RuntimeError('This should never run!') + # make sure we don't get complaints about the missing methods! 
+ # -- we prefer to use LightningDataModule + train_dataloader = None + test_dataloader = None + val_dataloader = None + predict_dataloader = None # ========================================================================= # diff --git a/disent/nn/weights.py b/disent/nn/weights.py index 1a09b852..6c3ef56c 100644 --- a/disent/nn/weights.py +++ b/disent/nn/weights.py @@ -25,6 +25,7 @@ import logging from typing import Optional +import torch from torch import nn from disent.util.strings import colors as c @@ -37,6 +38,24 @@ # ========================================================================= # +_WEIGHT_INIT_FNS = { + 'xavier_uniform': lambda weight: nn.init.xavier_uniform_(weight, gain=1.0), # gain=1 + 'xavier_normal': lambda weight: nn.init.xavier_normal_(weight, gain=1.0), # gain=1 + 'xavier_normal__0.1': lambda weight: nn.init.xavier_normal_(weight, gain=0.1), # gain=0.1 + # kaiming -- also known as "He initialisation" + 'kaiming_uniform': lambda weight: nn.init.kaiming_uniform_(weight, a=0, mode='fan_in', nonlinearity='relu'), # fan_in, relu + 'kaiming_normal': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_in', nonlinearity='relu'), # fan_in, relu + 'kaiming_normal__fan_out': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_out', nonlinearity='relu'), # fan_out, relu + # other + 'orthogonal': lambda weight: nn.init.orthogonal_(weight, gain=1), # gain=1 + 'normal': lambda weight: nn.init.normal_(weight, mean=0., std=1.), # std=1 + 'normal__0.1': lambda weight: nn.init.normal_(weight, mean=0., std=0.1), # std=0.1 + 'normal__0.01': lambda weight: nn.init.normal_(weight, mean=0., std=0.01), # std=0.01 + 'normal__0.001': lambda weight: nn.init.normal_(weight, mean=0., std=0.001), # std=0.001 +} + + +# TODO: clean this up! this is terrible... 
def init_model_weights(model: nn.Module, mode: Optional[str] = 'xavier_normal', log_level=logging.INFO) -> nn.Module: count = 0 @@ -44,20 +63,20 @@ def init_model_weights(model: nn.Module, mode: Optional[str] = 'xavier_normal', if mode is None: mode = 'default' - def init_normal(m): + def _apply_init_weights(m): nonlocal count init, count = False, count + 1 # actually initialise! - if mode == 'xavier_normal': + if mode == 'default': + pass + elif mode in _WEIGHT_INIT_FNS: if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - nn.init.xavier_normal_(m.weight) + _WEIGHT_INIT_FNS[mode](m.weight) nn.init.zeros_(m.bias) init = True - elif mode == 'default': - pass else: - raise KeyError(f'Unknown init mode: {repr(mode)}, valid modes are: {["xavier_normal", "default"]}') + raise KeyError(f'Unknown init mode: {repr(mode)}, valid modes are: {["default"] + sorted(_WEIGHT_INIT_FNS)}') # print messages if init: @@ -66,7 +85,7 @@ def init_normal(m): log.log(log_level, f'| {count:03d} {c.lRED}SKIP{c.RST}: {m.__class__.__name__}') log.log(log_level, f'Initialising Model Layers: {mode}') - model.apply(init_normal) + model.apply(_apply_init_weights) return model diff --git a/disent/registry/__init__.py b/disent/registry/__init__.py index 82814540..d834f1a8 100644 --- a/disent/registry/__init__.py +++ b/disent/registry/__init__.py @@ -34,8 +34,18 @@ eg. 
`DATASET.register(...options...)(your_function_or_class)` """ -from disent.registry._registry import Registry as _Registry -from disent.registry._registry import LazyImport as _LazyImport +# from disent.registry._registry import ProvidedValue +# from disent.registry._registry import StaticImport +# from disent.registry._registry import DictProviders +# from disent.registry._registry import RegexProvidersSearch + +from disent.registry._registry import StaticValue +from disent.registry._registry import LazyValue +from disent.registry._registry import LazyImport +from disent.registry._registry import Registry +from disent.registry._registry import RegistryImports +from disent.registry._registry import RegexConstructor +from disent.registry._registry import RegexRegistry # ========================================================================= # @@ -44,15 +54,19 @@ # TODO: this is not yet used in disent.data or disent.frameworks -DATASETS = _Registry('DATASETS') +DATASETS: RegistryImports['torch.utils.data.Dataset'] = RegistryImports('DATASETS') # groundtruth -- impl -DATASETS['cars3d'] = _LazyImport('disent.dataset.data._groundtruth__cars3d') -DATASETS['dsprites'] = _LazyImport('disent.dataset.data._groundtruth__dsprites') -DATASETS['mpi3d'] = _LazyImport('disent.dataset.data._groundtruth__mpi3d') -DATASETS['smallnorb'] = _LazyImport('disent.dataset.data._groundtruth__norb') -DATASETS['shapes3d'] = _LazyImport('disent.dataset.data._groundtruth__shapes3d') +DATASETS['cars3d_x128'] = LazyImport('disent.dataset.data._groundtruth__cars3d.Cars3dData') +DATASETS['cars3d'] = LazyImport('disent.dataset.data._groundtruth__cars3d.Cars3d64Data') +DATASETS['dsprites'] = LazyImport('disent.dataset.data._groundtruth__dsprites.DSpritesData') +DATASETS['mpi3d_toy'] = LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='toy') +DATASETS['mpi3d_realistic'] = LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='realistic') +DATASETS['mpi3d_real'] = 
LazyImport('disent.dataset.data._groundtruth__mpi3d.Mpi3dData', subset='real') +DATASETS['smallnorb_x96'] = LazyImport('disent.dataset.data._groundtruth__norb.SmallNorbData') +DATASETS['smallnorb'] = LazyImport('disent.dataset.data._groundtruth__norb.SmallNorb64Data') +DATASETS['shapes3d'] = LazyImport('disent.dataset.data._groundtruth__shapes3d.Shapes3dData') # groundtruth -- impl synthetic -DATASETS['xyobject'] = _LazyImport('disent.dataset.data._groundtruth__xyobject') +DATASETS['xyobject'] = LazyImport('disent.dataset.data._groundtruth__xyobject.XYObjectData') # ========================================================================= # @@ -63,18 +77,18 @@ # TODO: this is not yet used in disent.data or disent.frameworks # changes here should also update -SAMPLERS = _Registry('SAMPLERS') +SAMPLERS: RegistryImports['disent.dataset.sampling.BaseDisentSampler'] = RegistryImports('SAMPLERS') # [ground truth samplers] -SAMPLERS['gt_dist'] = _LazyImport('disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler') -SAMPLERS['gt_pair'] = _LazyImport('disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler') -SAMPLERS['gt_pair_orig'] = _LazyImport('disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler') -SAMPLERS['gt_single'] = _LazyImport('disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler') -SAMPLERS['gt_triple'] = _LazyImport('disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler') +SAMPLERS['gt_dist'] = LazyImport('disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler') +SAMPLERS['gt_pair'] = LazyImport('disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler') +SAMPLERS['gt_pair_orig'] = LazyImport('disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler') +SAMPLERS['gt_single'] = LazyImport('disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler') +SAMPLERS['gt_triple'] = 
LazyImport('disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler') # [any dataset samplers] -SAMPLERS['single'] = _LazyImport('disent.dataset.sampling._single.SingleSampler') -SAMPLERS['random'] = _LazyImport('disent.dataset.sampling._random__any.RandomSampler') +SAMPLERS['single'] = LazyImport('disent.dataset.sampling._single.SingleSampler') +SAMPLERS['random'] = LazyImport('disent.dataset.sampling._random__any.RandomSampler') # [episode samplers] -SAMPLERS['random_episode'] = _LazyImport('disent.dataset.sampling._random__episodes.RandomEpisodeSampler') +SAMPLERS['random_episode'] = LazyImport('disent.dataset.sampling._random__episodes.RandomEpisodeSampler') # ========================================================================= # @@ -87,19 +101,19 @@ # TODO: this is not yet used in disent.frameworks -FRAMEWORKS = _Registry('FRAMEWORKS') +FRAMEWORKS: RegistryImports['disent.frameworks.DisentFramework'] = RegistryImports('FRAMEWORKS') # [AE] -FRAMEWORKS['tae'] = _LazyImport('disent.frameworks.ae._supervised__tae.TripletAe') -FRAMEWORKS['ae'] = _LazyImport('disent.frameworks.ae._unsupervised__ae.Ae') +FRAMEWORKS['tae'] = LazyImport('disent.frameworks.ae._supervised__tae.TripletAe') +FRAMEWORKS['ae'] = LazyImport('disent.frameworks.ae._unsupervised__ae.Ae') # [VAE] -FRAMEWORKS['tvae'] = _LazyImport('disent.frameworks.vae._supervised__tvae.TripletVae') -FRAMEWORKS['betatc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae') -FRAMEWORKS['beta_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betavae.BetaVae') -FRAMEWORKS['dfc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dfcvae.DfcVae') -FRAMEWORKS['dip_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dipvae.DipVae') -FRAMEWORKS['info_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae') -FRAMEWORKS['vae'] = _LazyImport('disent.frameworks.vae._unsupervised__vae.Vae') -FRAMEWORKS['ada_vae'] = 
_LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae') +FRAMEWORKS['tvae'] = LazyImport('disent.frameworks.vae._supervised__tvae.TripletVae') +FRAMEWORKS['betatc_vae'] = LazyImport('disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae') +FRAMEWORKS['beta_vae'] = LazyImport('disent.frameworks.vae._unsupervised__betavae.BetaVae') +FRAMEWORKS['dfc_vae'] = LazyImport('disent.frameworks.vae._unsupervised__dfcvae.DfcVae') +FRAMEWORKS['dip_vae'] = LazyImport('disent.frameworks.vae._unsupervised__dipvae.DipVae') +FRAMEWORKS['info_vae'] = LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae') +FRAMEWORKS['vae'] = LazyImport('disent.frameworks.vae._unsupervised__vae.Vae') +FRAMEWORKS['ada_vae'] = LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae') # ========================================================================= # @@ -108,27 +122,30 @@ # ========================================================================= # -RECON_LOSSES = _Registry('RECON_LOSSES') +RECON_LOSSES: RegexRegistry['disent.frameworks.helper.reconstructions.ReconLossHandler'] = RegexRegistry('RECON_LOSSES') # TODO: we need a regex version of RegistryImports # [STANDARD LOSSES] -RECON_LOSSES['mse'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMse') # from the normal distribution - real values in the range [0, 1] -RECON_LOSSES['mae'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMae') # mean absolute error +RECON_LOSSES['mse'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMse') # from the normal distribution - real values in the range [0, 1] +RECON_LOSSES['mae'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMae') # mean absolute error # [STANDARD DISTRIBUTIONS] -RECON_LOSSES['bce'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBce') # from the bernoulli distribution - binary values in the set {0, 1} -RECON_LOSSES['bernoulli'] = 
_LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli') # reduces to bce - binary values in the set {0, 1} -RECON_LOSSES['c_bernoulli'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli') # bernoulli with a computed offset to handle values in the range [0, 1] -RECON_LOSSES['normal'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerNormal') # handle all real values +RECON_LOSSES['bce'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBce') # from the bernoulli distribution - binary values in the set {0, 1} +RECON_LOSSES['bernoulli'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli') # reduces to bce - binary values in the set {0, 1} +RECON_LOSSES['c_bernoulli'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli') # bernoulli with a computed offset to handle values in the range [0, 1] +RECON_LOSSES['normal'] = LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerNormal') # handle all real values +# [REGEX LOSSES] +RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_l(\d+\.\d+)_k(\d+\.\d+)_norm_([a-z]+)$', example='mse_xy8_abs63_l1.0_k1.0_norm_none', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l_w_n') +RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_norm_([a-z]+)$', example='mse_xy8_abs63_norm_none', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l1_w1_n') +RECON_LOSSES.register_regex(pattern=r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)$', example='mse_xy8_abs63', factory_fn='disent.frameworks.helper.reconstructions._make_aug_recon_loss_l1_w1_nnone') # ========================================================================= # -# LATENT_DISTS - should be synchronized with: # -# `disent/frameworks/helper/latent_distributions.py` # +# LATENT_HANDLERS - should be synchronized with: # +# 
`disent/frameworks/helper/latent_distributions.py` # # ========================================================================= # -# TODO: this is not yet used in disent.frameworks or disent.frameworks.helper.latent_distributions -LATENT_DISTS = _Registry('LATENT_DISTS') -LATENT_DISTS['normal'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal') -LATENT_DISTS['laplace'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace') +LATENT_HANDLERS: RegistryImports['disent.frameworks.helper.latent_distributions.LatentDistsHandler'] = RegistryImports('LATENT_HANDLERS') +LATENT_HANDLERS['normal'] = LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal') +LATENT_HANDLERS['laplace'] = LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace') # ========================================================================= # @@ -139,41 +156,40 @@ # default learning rate for each optimizer _LR = 1e-3 - -OPTIMIZERS = _Registry('OPTIMIZERS') +OPTIMIZERS: RegistryImports['torch.optim.Optimizer'] = RegistryImports('OPTIMIZERS') # [torch] -OPTIMIZERS['adadelta'] = _LazyImport(lr=_LR, import_path='torch.optim.adadelta.Adadelta') -OPTIMIZERS['adagrad'] = _LazyImport(lr=_LR, import_path='torch.optim.adagrad.Adagrad') -OPTIMIZERS['adam'] = _LazyImport(lr=_LR, import_path='torch.optim.adam.Adam') -OPTIMIZERS['adamax'] = _LazyImport(lr=_LR, import_path='torch.optim.adamax.Adamax') -OPTIMIZERS['adam_w'] = _LazyImport(lr=_LR, import_path='torch.optim.adamw.AdamW') -OPTIMIZERS['asgd'] = _LazyImport(lr=_LR, import_path='torch.optim.asgd.ASGD') -OPTIMIZERS['lbfgs'] = _LazyImport(lr=_LR, import_path='torch.optim.lbfgs.LBFGS') -OPTIMIZERS['rmsprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rmsprop.RMSprop') -OPTIMIZERS['rprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rprop.Rprop') -OPTIMIZERS['sgd'] = _LazyImport(lr=_LR, import_path='torch.optim.sgd.SGD') 
-OPTIMIZERS['sparse_adam'] = _LazyImport(lr=_LR, import_path='torch.optim.sparse_adam.SparseAdam') +OPTIMIZERS['adadelta'] = LazyImport(lr=_LR, import_path='torch.optim.adadelta.Adadelta') +OPTIMIZERS['adagrad'] = LazyImport(lr=_LR, import_path='torch.optim.adagrad.Adagrad') +OPTIMIZERS['adam'] = LazyImport(lr=_LR, import_path='torch.optim.adam.Adam') +OPTIMIZERS['adamax'] = LazyImport(lr=_LR, import_path='torch.optim.adamax.Adamax') +OPTIMIZERS['adam_w'] = LazyImport(lr=_LR, import_path='torch.optim.adamw.AdamW') +OPTIMIZERS['asgd'] = LazyImport(lr=_LR, import_path='torch.optim.asgd.ASGD') +OPTIMIZERS['lbfgs'] = LazyImport(lr=_LR, import_path='torch.optim.lbfgs.LBFGS') +OPTIMIZERS['rmsprop'] = LazyImport(lr=_LR, import_path='torch.optim.rmsprop.RMSprop') +OPTIMIZERS['rprop'] = LazyImport(lr=_LR, import_path='torch.optim.rprop.Rprop') +OPTIMIZERS['sgd'] = LazyImport(lr=_LR, import_path='torch.optim.sgd.SGD') +OPTIMIZERS['sparse_adam'] = LazyImport(lr=_LR, import_path='torch.optim.sparse_adam.SparseAdam') # [torch_optimizer] -OPTIMIZERS['acc_sgd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AccSGD') -OPTIMIZERS['ada_bound'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdaBound') -OPTIMIZERS['ada_mod'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdaMod') -OPTIMIZERS['adam_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdamP') -OPTIMIZERS['agg_mo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AggMo') -OPTIMIZERS['diff_grad'] = _LazyImport(lr=_LR, import_path='torch_optimizer.DiffGrad') -OPTIMIZERS['lamb'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Lamb') +OPTIMIZERS['acc_sgd'] = LazyImport(lr=_LR, import_path='torch_optimizer.AccSGD') +OPTIMIZERS['ada_bound'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdaBound') +OPTIMIZERS['ada_mod'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdaMod') +OPTIMIZERS['adam_p'] = LazyImport(lr=_LR, import_path='torch_optimizer.AdamP') +OPTIMIZERS['agg_mo'] = LazyImport(lr=_LR, 
import_path='torch_optimizer.AggMo') +OPTIMIZERS['diff_grad'] = LazyImport(lr=_LR, import_path='torch_optimizer.DiffGrad') +OPTIMIZERS['lamb'] = LazyImport(lr=_LR, import_path='torch_optimizer.Lamb') # 'torch_optimizer.Lookahead' is skipped because it is wrapped -OPTIMIZERS['novograd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.NovoGrad') -OPTIMIZERS['pid'] = _LazyImport(lr=_LR, import_path='torch_optimizer.PID') -OPTIMIZERS['qh_adam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHAdam') -OPTIMIZERS['qhm'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHM') -OPTIMIZERS['radam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RAdam') -OPTIMIZERS['ranger'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Ranger') -OPTIMIZERS['ranger_qh'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerQH') -OPTIMIZERS['ranger_va'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerVA') -OPTIMIZERS['sgd_w'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDW') -OPTIMIZERS['sgd_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDP') -OPTIMIZERS['shampoo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Shampoo') -OPTIMIZERS['yogi'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Yogi') +OPTIMIZERS['novograd'] = LazyImport(lr=_LR, import_path='torch_optimizer.NovoGrad') +OPTIMIZERS['pid'] = LazyImport(lr=_LR, import_path='torch_optimizer.PID') +OPTIMIZERS['qh_adam'] = LazyImport(lr=_LR, import_path='torch_optimizer.QHAdam') +OPTIMIZERS['qhm'] = LazyImport(lr=_LR, import_path='torch_optimizer.QHM') +OPTIMIZERS['radam'] = LazyImport(lr=_LR, import_path='torch_optimizer.RAdam') +OPTIMIZERS['ranger'] = LazyImport(lr=_LR, import_path='torch_optimizer.Ranger') +OPTIMIZERS['ranger_qh'] = LazyImport(lr=_LR, import_path='torch_optimizer.RangerQH') +OPTIMIZERS['ranger_va'] = LazyImport(lr=_LR, import_path='torch_optimizer.RangerVA') +OPTIMIZERS['sgd_w'] = LazyImport(lr=_LR, import_path='torch_optimizer.SGDW') +OPTIMIZERS['sgd_p'] = 
LazyImport(lr=_LR, import_path='torch_optimizer.SGDP') +OPTIMIZERS['shampoo'] = LazyImport(lr=_LR, import_path='torch_optimizer.Shampoo') +OPTIMIZERS['yogi'] = LazyImport(lr=_LR, import_path='torch_optimizer.Yogi') # ========================================================================= # @@ -182,12 +198,12 @@ # TODO: this is not yet used in disent.util.lightning.callbacks or disent.metrics -METRICS = _Registry('METRICS') -METRICS['dci'] = _LazyImport('disent.metrics._dci.metric_dci') -METRICS['factor_vae'] = _LazyImport('disent.metrics._factor_vae.metric_factor_vae') -METRICS['mig'] = _LazyImport('disent.metrics._mig.metric_mig') -METRICS['sap'] = _LazyImport('disent.metrics._sap.metric_sap') -METRICS['unsupervised'] = _LazyImport('disent.metrics._unsupervised.metric_unsupervised') +METRICS: RegistryImports['disent.metrics.utils._Metric'] = RegistryImports('METRICS') +METRICS['dci'] = LazyImport('disent.metrics._dci.metric_dci') +METRICS['factor_vae'] = LazyImport('disent.metrics._factor_vae.metric_factor_vae') +METRICS['mig'] = LazyImport('disent.metrics._mig.metric_mig') +METRICS['sap'] = LazyImport('disent.metrics._sap.metric_sap') +METRICS['unsupervised'] = LazyImport('disent.metrics._unsupervised.metric_unsupervised') # ========================================================================= # @@ -196,12 +212,12 @@ # TODO: this is not yet used in disent.framework or disent.schedule -SCHEDULES = _Registry('SCHEDULES') -SCHEDULES['clip'] = _LazyImport('disent.schedule._schedule.ClipSchedule') -SCHEDULES['cosine_wave'] = _LazyImport('disent.schedule._schedule.CosineWaveSchedule') -SCHEDULES['cyclic'] = _LazyImport('disent.schedule._schedule.CyclicSchedule') -SCHEDULES['linear'] = _LazyImport('disent.schedule._schedule.LinearSchedule') -SCHEDULES['noop'] = _LazyImport('disent.schedule._schedule.NoopSchedule') +SCHEDULES: RegistryImports['disent.schedule.Schedule'] = RegistryImports('SCHEDULES') +SCHEDULES['clip'] = 
LazyImport('disent.schedule._schedule.ClipSchedule') +SCHEDULES['cosine_wave'] = LazyImport('disent.schedule._schedule.CosineWaveSchedule') +SCHEDULES['cyclic'] = LazyImport('disent.schedule._schedule.CyclicSchedule') +SCHEDULES['linear'] = LazyImport('disent.schedule._schedule.LinearSchedule') +SCHEDULES['noop'] = LazyImport('disent.schedule._schedule.NoopSchedule') # ========================================================================= # @@ -210,17 +226,28 @@ # TODO: this is not yet used in disent.framework or disent.model -MODELS = _Registry('MODELS') +MODELS: RegistryImports['disent.model._base.DisentLatentsModule'] = RegistryImports('MODELS') # [DECODER] -MODELS['encoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.EncoderConv64') -MODELS['encoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.EncoderConv64Norm') -MODELS['encoder_fc'] = _LazyImport('disent.model.ae._vae_fc.EncoderFC') -MODELS['encoder_linear'] = _LazyImport('disent.model.ae._linear.EncoderLinear') +MODELS['encoder_conv64'] = LazyImport('disent.model.ae._vae_conv64.EncoderConv64') +MODELS['encoder_conv64norm'] = LazyImport('disent.model.ae._norm_conv64.EncoderConv64Norm') +MODELS['encoder_fc'] = LazyImport('disent.model.ae._vae_fc.EncoderFC') +MODELS['encoder_linear'] = LazyImport('disent.model.ae._linear.EncoderLinear') # [ENCODER] -MODELS['decoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.DecoderConv64') -MODELS['decoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.DecoderConv64Norm') -MODELS['decoder_fc'] = _LazyImport('disent.model.ae._vae_fc.DecoderFC') -MODELS['decoder_linear'] = _LazyImport('disent.model.ae._linear.DecoderLinear') +MODELS['decoder_conv64'] = LazyImport('disent.model.ae._vae_conv64.DecoderConv64') +MODELS['decoder_conv64norm'] = LazyImport('disent.model.ae._norm_conv64.DecoderConv64Norm') +MODELS['decoder_fc'] = LazyImport('disent.model.ae._vae_fc.DecoderFC') +MODELS['decoder_linear'] = 
LazyImport('disent.model.ae._linear.DecoderLinear') + + +# ========================================================================= # +# HELPER registries # +# ========================================================================= # + + +# TODO: add norm support with regex? +KERNELS: RegexRegistry['torch.Tensor'] = RegexRegistry('KERNELS') +KERNELS.register_regex(pattern=r'^box_r(\d+)$', example='box_r31', factory_fn='disent.dataset.transform._augment._make_box_kernel') +KERNELS.register_regex(pattern=r'^gau_r(\d+)$', example='gau_r31', factory_fn='disent.dataset.transform._augment._make_gaussian_kernel') # ========================================================================= # @@ -229,16 +256,17 @@ # registry of registries -REGISTRIES = _Registry('REGISTRIES') -REGISTRIES['DATASETS'] = DATASETS -REGISTRIES['SAMPLERS'] = SAMPLERS -REGISTRIES['FRAMEWORKS'] = FRAMEWORKS -REGISTRIES['RECON_LOSSES'] = RECON_LOSSES -REGISTRIES['LATENT_DISTS'] = LATENT_DISTS -REGISTRIES['OPTIMIZERS'] = OPTIMIZERS -REGISTRIES['METRICS'] = METRICS -REGISTRIES['SCHEDULES'] = SCHEDULES -REGISTRIES['MODELS'] = MODELS +REGISTRIES: Registry[Registry] = Registry('REGISTRIES') +REGISTRIES['DATASETS'] = StaticValue(DATASETS) +REGISTRIES['SAMPLERS'] = StaticValue(SAMPLERS) +REGISTRIES['FRAMEWORKS'] = StaticValue(FRAMEWORKS) +REGISTRIES['RECON_LOSSES'] = StaticValue(RECON_LOSSES) +REGISTRIES['LATENT_HANDLERS'] = StaticValue(LATENT_HANDLERS) +REGISTRIES['OPTIMIZERS'] = StaticValue(OPTIMIZERS) +REGISTRIES['METRICS'] = StaticValue(METRICS) +REGISTRIES['SCHEDULES'] = StaticValue(SCHEDULES) +REGISTRIES['MODELS'] = StaticValue(MODELS) +REGISTRIES['KERNELS'] = StaticValue(KERNELS) # ========================================================================= # diff --git a/disent/registry/_registry.py b/disent/registry/_registry.py index dbf4d911..05afba22 100644 --- a/disent/registry/_registry.py +++ b/disent/registry/_registry.py @@ -22,202 +22,663 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import inspect +import re +from abc import ABC from typing import Any from typing import Callable from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import MutableMapping from typing import NoReturn from typing import Optional -from typing import Sequence +from typing import Protocol +from typing import Set from typing import Tuple from typing import TypeVar +from typing import Union from disent.util.function import wrapped_partial -from disent.util.imports import import_obj_partial from disent.util.imports import _check_and_split_path +from disent.util.imports import import_obj +from disent.util.imports import import_obj_partial # ========================================================================= # -# Basic Cached Item # +# Type Hints # # ========================================================================= # +K = TypeVar('K') +V = TypeVar('V') T = TypeVar('T') +AliasesHint = Union[str, Tuple[str, ...]] + + +class _FactoryFn(Protocol[V]): + def __call__(self, *args) -> V: ... + + + + +# ========================================================================= # +# Provided Values # +# ========================================================================= # + + +class ProvidedValue(Generic[V], ABC): + """ + Base class for providing immutable values using the `get` method. + - Subclasses should override this + """ + + def get(self) -> V: + raise NotImplementedError + + def __repr__(self): + return f'{self.__class__.__name__}()' + +class StaticValue(ProvidedValue[V]): + """ + Provide static values. Simply a see-through wrapper + around already generated / constant values. 
+ """ -class LazyValue(object): + def __init__(self, value: V): + self._value = value - def __init__(self, generate_fn: Callable[[], T]): + def get(self) -> V: + return self._value + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self._value)})' + + +class StaticImport(StaticValue[V]): + def __init__(self, fn: V, *partial_args, **partial_kwargs): + super().__init__(wrapped_partial(fn, *partial_args, **partial_kwargs)) + + +class LazyValue(ProvidedValue[V]): + """ + Use a function to provide a value by generating and caching + the result only when this value is first needed. + """ + + def __init__(self, generate_fn: Callable[[], V]): assert callable(generate_fn) - self._is_generated = False self._generate_fn = generate_fn + self._is_generated = False self._value = None - def generate(self) -> T: - # replace value -- we don't actually need caching of the - # values since the registry replaces these items automatically, - # but LazyValue is exposed so it might be used unexpectedly. + def get(self) -> V: + # cache the value if not self._is_generated: self._is_generated = True self._value = self._generate_fn() - self._generate_fn = None - # return value return self._value + def clear(self): + self._is_generated = False + self._value = None + def __repr__(self): return f'{self.__class__.__name__}({repr(self._generate_fn)})' +class LazyImport(LazyValue[V]): + """ + Like lazy value, but instead takes in the import path to a callable object. + Any remaining args and kwargs are used to partially parameterize the object. + - The partial object is returned, and not called. This should be + the same as importing the value directly when `get` is called! 
+ """ + + def __init__(self, import_path: str, *partial_args, **partial_kwargs): + # function imports the object when called + def generate_fn(): + return import_obj_partial(import_path, *partial_args, **partial_kwargs) + # initialise the lazy value + super().__init__(generate_fn=generate_fn) + + # ========================================================================= # -# Import Helper # +# Provided Dict # # ========================================================================= # -class LazyImport(LazyValue): - def __init__(self, import_path: str, *partial_args, **partial_kwargs): - super().__init__( - generate_fn=lambda: import_obj_partial(import_path, *partial_args, **partial_kwargs), - ) +class DictProviders(MutableMapping[K, V]): + """ + A dictionary that only allows instances of + provided values to be added to it. + - The returned values are obtained directly from the providers + """ + + def __init__(self): + self._providers: Dict[K, ProvidedValue[V]] = {} + + def __getitem__(self, k: K) -> V: + return self._getitem(k) + + def __contains__(self, k: K): + return k in self._providers + + def __setitem__(self, k: K, v: ProvidedValue[V]) -> NoReturn: + self._setitem(k, v) + + def __delitem__(self, k: K) -> NoReturn: + del self._providers[k] + + def __len__(self) -> int: + return len(self._providers) + + def __iter__(self) -> Iterator[K]: + yield from self._providers + + # allow easier overriding in subclasses without calling super() which can get confusing + + def _getitem(self, k: K) -> V: + provider = self._providers[k] + return provider.get() + + def _setitem(self, k: K, v: ProvidedValue[V]) -> NoReturn: + if not isinstance(v, ProvidedValue): + raise TypeError(f'Values stored in {self.__class__.__name__} must be instances of: {ProvidedValue.__name__}, got: {repr(v)}') + self._providers[k] = v # ========================================================================= # -# Registry # +# Registry - Mixin # # 
========================================================================= # -_NONE = object() +class Registry(DictProviders[str, V]): + def __init__(self, name: str): + if not str.isidentifier(name): + raise ValueError(f'Registry names must be valid identifiers, got: {repr(name)}') + # initialise + self._name = name + super().__init__() -class Registry(object): + @property + def static_examples(self) -> List[str]: + return list(self._providers.keys()) - def __init__( - self, - name: str, - assert_valid_key: Optional[Callable[[str], NoReturn]] = None, - assert_valid_value: Optional[Callable[[T], NoReturn]] = None, - ): - # checks! - assert str.isidentifier(name), f'The given name for the registry is not a valid identifier: {repr(name)}' - self._name = name - assert (assert_valid_key is None) or callable(assert_valid_key), f'assert_valid_key must be None or callable' - assert (assert_valid_value is None) or callable(assert_valid_value), f'assert_valid_value must be None or callable' - self._assert_valid_key = assert_valid_key - self._assert_valid_value = assert_valid_value - # storage - self._keys_to_values: Dict[str, Any] = {} + @property + def examples(self) -> List[str]: + return self.static_examples @property def name(self) -> str: return self._name - def _get_aliases(self, name, aliases, auto_alias: bool): - if auto_alias: - if name not in self: - aliases = (name, *aliases) - elif not aliases: - raise RuntimeError(f'automatic alias: {repr(name)} already exists but no alternative aliases were specified.') + def __repr__(self): + return f'{self.__class__.__name__}({self._name})' + + # --- CORE --- # + + def __setitem__(self, aliases: AliasesHint, v: ProvidedValue[V]) -> NoReturn: + self._setitems(aliases, v) + + def __getitem__(self, k: str) -> V: + value = self._getitem(k) + self._check_provided_value(value) + return value + + def __delitem__(self, k: str) -> None: + raise RuntimeError(f'Registry: {repr(self.name)} does not support item deletion. 
Tried to remove key: {repr(k)}') + + # --- HELPER --- # + + def _setitems(self, aliases: AliasesHint, v: Union[V, ProvidedValue[V]]) -> None: + aliases = self._normalise_aliases(aliases) + # check all the aliases + for k in aliases: + if not str.isidentifier(k): + raise ValueError(f'Keys stored in registry: {repr(self.name)} must be valid identifiers, got: {repr(k)}') + if k in self: + raise RuntimeError(f'Tried to overwrite existing key: {repr(k)} in registry: {repr(self.name)}') + self._check_key(k) + # check the value + v = self._check_and_normalise_value(v) + # set all the aliases + for k in aliases: + self._setitem(k, v) + + def _normalise_aliases(self, aliases: AliasesHint, check_nonempty: bool = True) -> Tuple[str]: + if isinstance(aliases, str): + aliases = (aliases,) + if not isinstance(aliases, tuple): + raise TypeError(f'Multiple aliases must be provided to registry: {repr(self.name)} as a Tuple[str], got: {repr(aliases)}') + if check_nonempty: + if len(aliases) < 1: + raise ValueError(f'At least one alias must be provided to registry: {repr(self.name)}, got: {repr(aliases)}') return aliases + # --- OVERRIDABLE --- # + + def _check_and_normalise_value(self, v: ProvidedValue[V]) -> ProvidedValue[V]: + return v + + def _check_provided_value(self, v: V) -> NoReturn: + pass + + def _check_key(self, k: str) -> NoReturn: + pass + + # --- MISSING VALUES --- # + + def setmissing(self, alias: AliasesHint, value: V) -> NoReturn: + # find missing keys + aliases = self._normalise_aliases(alias) + missing = tuple(alias for alias in aliases if (alias not in self)) + # register missing keys + if missing: + self._setitems(missing, value) + + @property + def setm(self) -> '_RegistrySetMissing': + # instead of checking values manually, at the cost of some efficiency, + # this allows us to register values multiple times with hardly modified notation! 
+ # -- only modifies unset values + # set once: `REGISTRY['key'] = val` + # set default: `REGISTRY.setm['key'] = val` + return self._RegistrySetMissing(self) + + class _RegistrySetMissing(object): + def __init__(self, registry: 'Registry'): + self._registry = registry + + def __setitem__(self, aliases: str, v: ProvidedValue[V]) -> NoReturn: + self._registry.setmissing(aliases, v) + + +# ========================================================================= # +# Import Registry # +# ========================================================================= # + + +# TODO: merge this with the dynamic registry below? +class RegistryImports(Registry[V]): + """ + A registry for arbitrary imports. + -- supports decorating functions and classes + """ + def register( self, - fn=_NONE, - aliases: Sequence[str] = (), + aliases: Optional[AliasesHint] = None, auto_alias: bool = True, partial_args: Tuple[Any, ...] = None, partial_kwargs: Dict[str, Any] = None, ) -> Callable[[T], T]: + """ + Register a function or object to this registry. + - can be used as a decorator @register(...) 
+ - automatically chooses an alias based on the function name + - specify defaults for the function with the args and kwargs + """ + # default values + if aliases is None: aliases = () + if partial_args is None: partial_args = () + if partial_kwargs is None: partial_kwargs = {} + aliases = self._normalise_aliases(aliases, check_nonempty=False) + + # add the function name as an alias if it does not already exist, + # then register the partially parameterised function as a static value def _decorator(orig_fn): - # try add the name of the object - keys = self._get_aliases(orig_fn.__name__, aliases=aliases, auto_alias=auto_alias) - # wrap function - new_fn = orig_fn - if (partial_args is not None) or (partial_kwargs is not None): - new_fn = wrapped_partial( - orig_fn, - *(() if partial_args is None else partial_args), - **({} if partial_kwargs is None else partial_kwargs), - ) - # register the function - self.register_value(value=new_fn, aliases=keys) + keys = self._append_auto_alias(self._get_fn_alias(orig_fn), aliases=aliases, auto_alias=auto_alias) + self[keys] = StaticImport(orig_fn, *partial_args, **partial_kwargs) return orig_fn - # handle case - if fn is _NONE: - return _decorator - else: - return _decorator(fn) + return _decorator def register_import( self, import_path: str, - aliases: Sequence[str] = (), + aliases: Optional[AliasesHint] = None, auto_alias: bool = True, *partial_args, **partial_kwargs, - ) -> 'Registry': - (*_, name) = _check_and_split_path(import_path) - return self.register_value( - value=LazyImport(import_path=import_path, *partial_args, **partial_kwargs), - aliases=self._get_aliases(name, aliases=aliases, auto_alias=auto_alias), - ) - - def register_value(self, value: T, aliases: Sequence[str]) -> 'Registry': - # check keys - if len(aliases) < 1: - raise ValueError(f'aliases must be specified, got an empty sequence') - for k in aliases: - if not str.isidentifier(k): - raise ValueError(f'alias is not a valid identifier: {repr(k)}') - if k in 
self._keys_to_values: - raise RuntimeError(f'registry already contains key: {repr(k)}') - self.assert_valid_key(k) - # handle lazy values -- defer checking if a lazy value - if not isinstance(value, LazyValue): - self.assert_valid_value(value) - # register keys - for k in aliases: - self._keys_to_values[k] = value - return self + ) -> NoReturn: + """ + Register an import path and automatically obtain an alias from it. + - This is the same as: registry[(import_name, *aliases)] = LazyImport(import_path, *partial_args, **partial_kwargs) + """ + # normalise aliases + if aliases is None: aliases = () + aliases = self._normalise_aliases(aliases, check_nonempty=False) + # add object alias + (*_, alias) = _check_and_split_path(import_path) + aliases = self._append_auto_alias(alias, aliases=aliases, auto_alias=auto_alias) + # register the lazy import -- pass the path positionally, otherwise any partial_args would clash with the `import_path` keyword + self[aliases] = LazyImport(import_path, *partial_args, **partial_kwargs) + + # --- ALIAS HELPER --- # + + def _append_auto_alias(self, alias: Optional[str], aliases: Tuple[str, ...], auto_alias: bool): + if auto_alias: + if alias is not None: + if alias not in self: + aliases = (alias, *aliases) + elif not aliases: + raise RuntimeError(f'automatic alias: {repr(alias)} already exists for registry: {repr(self.name)} and no alternative aliases were specified.') + elif not aliases: + raise RuntimeError(f'Cannot add value to registry: {repr(self.name)}, no automatic alias was found!') + elif not aliases: + raise RuntimeError(f'Cannot add value to registry: {repr(self.name)}, no manual aliases were specified and automatic aliasing is disabled!') + return aliases - def __setitem__(self, aliases: str, value: T): - if isinstance(aliases, str): - aliases = (aliases,) - assert isinstance(aliases, tuple), f'multiple aliases must be provided as a Tuple[str], got: {repr(aliases)}' - self.register_value(value=value, aliases=aliases) - - def __contains__(self, key: str): - return key in self._keys_to_values - - def
__getitem__(self, key: str): - if key not in self._keys_to_values: - raise KeyError(f'registry does not contain the key: {repr(key)}, valid keys include: {sorted(self._keys_to_values.keys())}') - # get the value - value = self._keys_to_values[key] - # replace lazy value - if isinstance(value, LazyValue): - value = value.generate() - # check value & run deferred checks - if isinstance(value, LazyValue): - raise RuntimeError(f'{LazyValue.__name__} should never return other lazy values.') - self.assert_valid_value(value) - # update the value - self._keys_to_values[key] = value - # return the value - return value + @staticmethod + def _get_fn_alias(fn) -> Optional[str]: + if hasattr(fn, '__name__'): + if str.isidentifier(fn.__name__): + return fn.__name__ + return None - def __iter__(self): - yield from self._keys_to_values.keys() + # --- OVERRIDDEN --- # - def assert_valid_value(self, value: T) -> T: - if self._assert_valid_value is not None: - self._assert_valid_value(value) - return value + def _check_and_normalise_value(self, v: ProvidedValue[V]) -> ProvidedValue[V]: + if not isinstance(v, (LazyImport, StaticImport)): + raise TypeError(f'Values stored in registry: {repr(self.name)} must be instances of: {(LazyImport.__name__, StaticImport.__name__)}, got: {repr(v)}') + return v - def assert_valid_key(self, key: str) -> str: - if self._assert_valid_key is not None: - self._assert_valid_key(key) - return key + # --- OVERRIDABLE --- # - def __repr__(self): - return f'{self.__class__.__name__}({repr(self._name)}, ...)' + def _check_provided_value(self, v: V) -> NoReturn: + pass + + def _check_key(self, k: str) -> NoReturn: + pass + + +# ========================================================================= # +# Dynamic Registry # +# ========================================================================= # + +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # +# Constructor # +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - # + + +class RegexConstructor(object): + + def __init__( + self, + pattern: Union[str, re.Pattern], + example: str, + factory_fn: Union[_FactoryFn[V], str], + ): + self._pattern = self._check_pattern(pattern) + self._example = self._check_example(example, self._pattern) + # we can delay loading of the function if it is a string! + self._factory_fn = factory_fn if isinstance(factory_fn, str) else self._check_factory_fn(factory_fn, self._pattern) + + @classmethod + def _check_pattern(cls, pattern: Union[str, re.Pattern]): + # check the regex type & convert + if isinstance(pattern, str): + pattern = re.compile(pattern) + if not isinstance(pattern, re.Pattern): + raise TypeError(f'regex pattern must be a regex `str` or `re.Pattern`, got: {repr(pattern)}') + if pattern.groups < 1: + raise ValueError(f'regex pattern must contain at least one group, got: {repr(pattern)}') + return pattern + + @classmethod + def _check_factory_fn(cls, factory_fn: _FactoryFn[V], pattern: re.Pattern) -> _FactoryFn[V]: + # we have an actual function, we can check it! + if not callable(factory_fn): + raise TypeError(f'generator function must be callable, got: {factory_fn}') + signature = inspect.signature(factory_fn) + if len(signature.parameters) != pattern.groups: + raise ValueError(f'signature has incorrect number of parameters: {repr(signature)} compared to the number of groups in the regex pattern: {repr(pattern)}') + return factory_fn + + @classmethod + def _check_example(cls, example: str, pattern: re.Pattern) -> str: + # check the example + if not isinstance(example, str): + raise TypeError(f'example must be a `str`, got: {type(example)}') + if not example: + raise ValueError(f'example must not be empty, got: {type(example)}') + # check that the regex matches the example! 
+ if pattern.search(example) is None: + raise ValueError(f'could not match example: {repr(example)} to regex: {repr(pattern)}') + return example + + @property + def pattern(self) -> re.Pattern: + return self._pattern + + @property + def example(self) -> str: + return self._example + + def construct(self, name: str) -> V: + # get the results + result = self._pattern.search(name) + if result is None: + raise KeyError(f'pattern: {self.pattern} does not match given name: {repr(name)}. The following example would be valid: {repr(self.example)}') + # get the function -- load via the path + if isinstance(self._factory_fn, str): + fn = import_obj(self._factory_fn) + self._factory_fn = self._check_factory_fn(fn, self._pattern) + # construct + return self._factory_fn(*result.groups()) + + def can_construct(self, name: str) -> bool: + return self._pattern.search(name) is not None + + +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # +# Cached Linear Search # +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + +class RegexProvidersSearch(object): + + def __init__(self): + self._patterns: Set[re.Pattern] = set() + self._constructors: List[RegexConstructor] = [] + # caching + self._cache = {} + self._cache_dirty = False + + @property + def regex_constructors(self) -> List[RegexConstructor]: + return list(self._constructors) + + def construct(self, arg_str: str): + provider = self.get_constructor(arg_str) + # build the object + if provider is not None: + return provider.construct(arg_str) + # no result was found! + raise KeyError(f'could not construct an item from the given argument string: {repr(arg_str)}, valid patterns include: {[p.pattern for p in self._constructors]}') + + def can_construct(self, arg_str: str) -> bool: + return self.get_constructor(arg_str) is not None + + def get_constructor(self, arg_str: str) -> Optional[RegexConstructor]: + # TODO: clean up this cache! 
+ # check cache -- remove None entries if dirty + if self._cache_dirty: + self._cache = {k: v for k, v in self._cache.items() if v is not None} + self._cache_dirty = False + if arg_str in self._cache: + return self._cache[arg_str] + # check the input string + if not isinstance(arg_str, str): + raise TypeError(f'regex factory can only construct from `str`, got: {repr(arg_str)}') + if not arg_str: + raise ValueError(f'regex factory can only construct from non-empty `str`, got: {repr(arg_str)}') + # match the values + constructor = None + for c in self: + if c.can_construct(arg_str): + constructor = c + break + # cache the value + self._cache[arg_str] = constructor + if len(self._cache) > 128: + self._cache.popitem() + return constructor + + def has_pattern(self, pattern: Union[str, re.Pattern]) -> bool: + if isinstance(pattern, str): + pattern = re.compile(pattern) + return pattern in self._patterns + + def __len__(self) -> int: + return len(self._constructors) + + def __iter__(self) -> Iterator[RegexConstructor]: + yield from self._constructors + + def append(self, constructor: RegexConstructor): + if not isinstance(constructor, RegexConstructor): + raise TypeError(f'regex factory only accepts {RegexConstructor.__name__} providers.') + if constructor.pattern in self._patterns: + raise RuntimeError(f'regex factory already contains the regex pattern: {repr(constructor.pattern)}') + # append value! + self._patterns.add(constructor.pattern) + self._constructors.append(constructor) + self._cache_dirty = True + + +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # +# Dynamic Registry # +# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + +class RegexRegistry(Registry[V]): + + """ + Registry that allows registering of regex expressions that can be used to + construct values if there is no static value found! + - Regular expressions are checked in the order they are registered. 
+ - `name in registry` checks if any of the expressions match, it does not check for an existing regex + - `len(registry)` returns the number of examples available, one for each item & regex factory + - `for example in registry` yields the examples available, one for each item & regex factory. These examples can be used to construct values with `registry[example]` + + To check for an already added regex expression, use: + - `has_regex(expr)` + """ + + def __init__(self, name: str): + self._regex_providers = RegexProvidersSearch() + super().__init__(name) + + # --- CORE ... UPDATED WITH LINEAR SEARCH --- # + + @property + def regex_constructors(self) -> List[RegexConstructor]: + return self._regex_providers.regex_constructors + + @property + def regex_examples(self) -> List[str]: + return [constructor.example for constructor in self._regex_providers.regex_constructors] + + @property + def examples(self) -> List[str]: + return [*self.static_examples, *self.regex_examples] + + def __getitem__(self, k: str) -> V: + assert isinstance(k, str), f'invalid key: {repr(k)}, must be a `str`' + # the regex provider is cached so this should be efficient for the same value calls + # -- we do not cache the actual provided value! + if k in self._providers: + return self._getitem(k) + elif self._regex_providers.can_construct(k): + return self._regex_providers.construct(k) + raise KeyError(f'dynamic registry: {repr(self.name)} cannot construct item with key: {repr(k)}. Valid static values: {sorted(self._providers.keys())}.
Valid dynamic examples: {[p.example for p in self._regex_providers]}') + + def __setitem__(self, aliases: AliasesHint, v: ProvidedValue[V]) -> NoReturn: + if isinstance(aliases, re.Pattern) or isinstance(v, RegexConstructor): + raise RuntimeError(f'register dynamic values to the dynamic registry: {repr(self.name)} with the `register_regex` or `register_constructor` methods.') + super().__setitem__(aliases, v) + + def __contains__(self, k: K): + if k in self._providers: + return True + if self._regex_providers.can_construct(k): + return True + return False + + def __len__(self) -> int: + return len(self._providers) + len(self._regex_providers) + + def __iter__(self) -> Iterator[K]: + yield from self._providers + yield from (p.example for p in self._regex_providers) + + # --- OVERRIDABLE --- # + + def _check_regex_constructor(self, constructor: RegexConstructor): + pass + + # --- DYNAMIC VALUES --- # + + def has_regex(self, pattern: Union[str, re.Pattern]) -> bool: + return self._regex_providers.has_pattern(pattern) + + def register_constructor(self, constructor: RegexConstructor) -> 'RegexRegistry': + """ + Register a regex constructor + """ + if not isinstance(constructor, RegexConstructor): + raise TypeError(f'dynamic registry: {repr(self.name)} only accepts dynamic {RegexConstructor.__name__}, got: {repr(constructor)}') + self._check_regex_constructor(constructor) + self._regex_providers.append(constructor) + return self + + def register_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None): + """ + Register and create a regex constructor + """ + def _register_wrapper(fn: T) -> T: + self.register_constructor(RegexConstructor(pattern=pattern, example=example, factory_fn=fn)) + return fn + return _register_wrapper if (factory_fn is None) else _register_wrapper(factory_fn) + + def register_missing_constructor(self, constructor: RegexConstructor): + """ + Only register a regex constructor if the pattern 
does not already exist! + """ + if not self.has_regex(constructor.pattern): + return self.register_constructor(constructor) + + def register_missing_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None): + """ + Only register and create a regex constructor if the pattern does not already exist! + """ + if not self.has_regex(pattern): + return self.register_regex(pattern=pattern, example=example, factory_fn=factory_fn) + elif factory_fn is None: + return lambda fn: fn # dummy wrapper + + # --- MISSING VALUES --- # + + # override from the parent class! + class _RegistrySetMissing(Registry._RegistrySetMissing): + + _registry: 'RegexRegistry' + + def register_constructor(self, constructor: RegexConstructor): + """ + Only register a regex constructor if the pattern does not already exist! + """ + return self._registry.register_missing_constructor(constructor=constructor) + + def register_regex(self, pattern: Union[str, re.Pattern], example: str, factory_fn: Optional[Union[_FactoryFn[V], str]] = None): + """ + Only register and create a regex constructor if the pattern does not already exist! 
+ """ + return self._registry.register_missing_regex(pattern=pattern, example=example, factory_fn=factory_fn, ) # ========================================================================= # diff --git a/disent/schedule/__init__.py b/disent/schedule/__init__.py index c9ea8030..f20d4842 100644 --- a/disent/schedule/__init__.py +++ b/disent/schedule/__init__.py @@ -30,12 +30,17 @@ from ._schedule import CyclicSchedule from ._schedule import LinearSchedule from ._schedule import NoopSchedule +from ._schedule import MultiplySchedule +from ._schedule import FixedValueSchedule from ._schedule import SingleSchedule + # aliases from ._schedule import ClipSchedule as Clip from ._schedule import CosineWaveSchedule as CosineWave from ._schedule import CyclicSchedule as Cyclic from ._schedule import LinearSchedule as Linear from ._schedule import NoopSchedule as Noop +from ._schedule import MultiplySchedule as Multiply +from ._schedule import FixedValueSchedule as FixedValue from ._schedule import SingleSchedule as Single diff --git a/disent/schedule/_schedule.py b/disent/schedule/_schedule.py index 41db86d5..13632676 100644 --- a/disent/schedule/_schedule.py +++ b/disent/schedule/_schedule.py @@ -58,6 +58,52 @@ def compute_value(self, step: int, value): return value +class MultiplySchedule(Schedule): + """ + A schedule that always applies a constant multiplier/ratio to the input value + """ + + # This schedule will always return a constant value! + + def __init__(self, r: float = 1.0): + """ + :param r: The constant ratio of the original value that the schedule will use + """ + self.r = r + + def compute_value(self, step: int, value): + # we always return a constant value! + return value * self.r + + +class FixedValueSchedule(Schedule): + """ + Set a new constant value, instead of using the value passed to `compute_value`. 
+ - We can override config values using this class + """ + + def __init__( + self, + value: float, + schedule: Optional[Schedule] = None, + ): + """ + :param schedule: The wrapped schedule that is passed the new constant value + :param value: The value that should be used to replace the original values from the config. + If `compute_value` is called, the value passed to the function is replaced with this one! + """ + assert (schedule is None) or isinstance(schedule, Schedule) + self.schedule = schedule + self.value = value + + def compute_value(self, step: int, value): + del value + # we override the passed value, and pass in a constant value instead! + if self.schedule is None: + return self.value + else: + return self.schedule(step, self.value) + # ========================================================================= # # Value Schedules # # ========================================================================= # diff --git a/disent/util/deprecate.py b/disent/util/deprecate.py index 96fc654c..75b9a8a8 100644 --- a/disent/util/deprecate.py +++ b/disent/util/deprecate.py @@ -69,7 +69,7 @@ def _get_stack_file_strings(): DEFAULT_TRACEBACK_MODE = 'first' -def deprecated(msg: str, traceback_mode: Optional[str] = None): +def deprecated(msg: str, traceback_mode: Optional[str] = None, fn=None): """ Mark a function or class as deprecated, and print a warning the first time it is used. 
@@ -114,7 +114,12 @@ def _caller(*args, **kwargs): else: fn = _caller return fn - return _decorator + + # handle function used as decorator, or called directly + if fn is not None: + return _decorator(fn) + else: + return _decorator # ========================================================================= # diff --git a/disent/util/inout/paths.py b/disent/util/inout/paths.py index 1ad76ded..7f6ea5ac 100644 --- a/disent/util/inout/paths.py +++ b/disent/util/inout/paths.py @@ -42,10 +42,33 @@ def modify_file_name(file: Union[str, Path], prefix: str = None, suffix: str = N path = Path(file) assert path.name, f'file name cannot be empty: {repr(path)}, for name: {repr(path.name)}' # create new path - prefix = '' if (prefix is None) else f'{prefix}{sep}' - suffix = '' if (suffix is None) else f'{sep}{suffix}' + prefix = '' if (not prefix) else f'{prefix}{sep}' + suffix = '' if (not suffix) else f'{sep}{suffix}' new_path = path.parent.joinpath(f'{prefix}{path.name}{suffix}') - # return path + # return path with same format as input + return str(new_path) if isinstance(file, str) else new_path + + +def modify_name_keep_ext(file: Union[str, Path], prefix: str = None, suffix: str = None, name_contains_sep: bool = False): + # get path components + path = Path(file) + name = path.name + assert name, f'file name cannot be empty: {repr(path)}, for name: {repr(name)}' + # handle suffix + if suffix: + components = name.rsplit('.', 1) if name_contains_sep else path.name.split('.', 1) + if len(components) >= 2: + [name, ext] = components + name = f'{name}{suffix}.{ext}' + else: + [name] = components + name = f'{name}{suffix}' + # handle prefix + if prefix: + name = f'{prefix}{name}' + # create new path + new_path = path.parent.joinpath(name) + # return path with same format as input return str(new_path) if isinstance(file, str) else new_path diff --git a/disent/util/jit.py b/disent/util/jit.py new file mode 100644 index 00000000..38083628 --- /dev/null +++ b/disent/util/jit.py @@ 
-0,0 +1,53 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# ========================================================================= # +# Numba Is An Optional Dependency # +# ========================================================================= # + + +def try_njit(*args, **kwargs): + """ + Wrapper around numba.njit + - If numba is installed, then we JIT the decorated function + - If numba is missing, then we do nothing and leave the function untouched! 
+ """ + try: + from numba import njit + except ImportError: + # dummy njit + def njit(*args, **kwargs): + def _wrapper(func): + import warnings + warnings.warn(f'failed to JIT compile: {func}, numba is not installed!') + return func + return _wrapper + # try and JIT compile function! + return njit(*args, **kwargs) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/lightning/callbacks/__init__.py b/disent/util/lightning/callbacks/__init__.py index 60b1dff4..ec300e21 100644 --- a/disent/util/lightning/callbacks/__init__.py +++ b/disent/util/lightning/callbacks/__init__.py @@ -25,8 +25,7 @@ from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic from disent.util.lightning.callbacks._callbacks_base import BaseCallbackTimed -from disent.util.lightning.callbacks._callbacks_pl import LoggerProgressCallback - -from disent.util.lightning.callbacks._callbacks_vae import VaeMetricLoggingCallback -from disent.util.lightning.callbacks._callbacks_vae import VaeLatentCycleLoggingCallback -from disent.util.lightning.callbacks._callbacks_vae import VaeGtDistsLoggingCallback +from disent.util.lightning.callbacks._callback_print_progress import LoggerProgressCallback +from disent.util.lightning.callbacks._callback_log_metrics import VaeMetricLoggingCallback +from disent.util.lightning.callbacks._callback_vis_latents import VaeLatentCycleLoggingCallback +from disent.util.lightning.callbacks._callback_vis_dists import VaeGtDistsLoggingCallback diff --git a/disent/util/lightning/callbacks/_callback_log_metrics.py b/disent/util/lightning/callbacks/_callback_log_metrics.py new file mode 100644 index 00000000..b2fafead --- /dev/null +++ b/disent/util/lightning/callbacks/_callback_log_metrics.py @@ -0,0 +1,129 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright 
(c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import warnings +from typing import Optional +from typing import Sequence + +import pytorch_lightning as pl + +from disent import registry as R +from disent.dataset.data import GroundTruthData +from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic +from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like +from disent.util.lightning.logger_util import log_metrics +from disent.util.lightning.logger_util import wb_log_reduced_summaries +from disent.util.profiling import Timer +from disent.util.strings import colors as c + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Helper # +# ========================================================================= # + + +def _normalized_numeric_metrics(items: dict): + results = {} + for k, v in items.items(): + if isinstance(v, (float, int)): + results[k] = v + else: + try: + results[k] = float(v) + except: + log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}') + return results + + +# ========================================================================= # +# Metrics Callback # +# ========================================================================= # + + +class VaeMetricLoggingCallback(BaseCallbackPeriodic): + + def __init__( + self, + step_end_metrics: Optional[Sequence[str]] = None, + train_end_metrics: Optional[Sequence[str]] = None, + every_n_steps: Optional[int] = None, + begin_first_step: bool = False, + ): + super().__init__(every_n_steps, begin_first_step) + self.step_end_metrics = step_end_metrics if step_end_metrics else [] + self.train_end_metrics = train_end_metrics if train_end_metrics else [] + assert isinstance(self.step_end_metrics, list) + assert isinstance(self.train_end_metrics, list) + assert self.step_end_metrics or 
self.train_end_metrics, 'No metrics given to step_end_metrics or train_end_metrics' + + def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningModule, metrics: list, is_final=False): + # get dataset and vae framework from trainer and module + dataset, vae = _get_dataset_and_ae_like(trainer, pl_module, unwrap_groundtruth=True) + # check if we need to skip + # TODO: dataset needs to be able to handle wrapped datasets! + if not dataset.is_ground_truth: + warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!') + return + # get padding amount + pad = max(7+len(k) for k in R.METRICS) # I know this is a magic variable... im just OCD + # compute all metrics + for metric in metrics: + if is_final: + log.info(f'| {metric.__name__:<{pad}} - computing...') + with Timer() as timer: + scores = metric(dataset, lambda x: vae.encode(x.to(vae.device))) + metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items()) + log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}') + + # log to trainer + prefix = 'final_metric' if is_final else 'epoch_metric' + prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()} + log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores)) + + # log summary for WANDB + # this is kinda hacky... the above should work for parallel coordinate plots + wb_log_reduced_summaries(trainer.logger, prefixed_scores, reduction='max') + + def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + if self.step_end_metrics: + log.debug('Computing Epoch Metrics:') + with Timer() as timer: + self._compute_metrics_and_log(trainer, pl_module, metrics=self.step_end_metrics, is_final=False) + log.debug(f'Computed Epoch Metrics! 
{timer.pretty}') + + def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + if self.train_end_metrics: + log.debug('Computing Final Metrics...') + with Timer() as timer: + self._compute_metrics_and_log(trainer, pl_module, metrics=self.train_end_metrics, is_final=True) + log.debug(f'Computed Final Metrics! {timer.pretty}') + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/lightning/callbacks/_callbacks_pl.py b/disent/util/lightning/callbacks/_callback_print_progress.py similarity index 78% rename from disent/util/lightning/callbacks/_callbacks_pl.py rename to disent/util/lightning/callbacks/_callback_print_progress.py index 5e8086aa..f2ee345e 100644 --- a/disent/util/lightning/callbacks/_callbacks_pl.py +++ b/disent/util/lightning/callbacks/_callback_print_progress.py @@ -23,7 +23,6 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging -import warnings import pytorch_lightning as pl @@ -39,7 +38,11 @@ class LoggerProgressCallback(BaseCallbackTimed): - + + def __init__(self, interval: float = 10, log_level: int = logging.INFO): + super().__init__(interval=interval) + self._log_level = log_level + def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, current_time, start_time): # get missing vars trainer_max_epochs = trainer.max_epochs if (trainer.max_epochs is not None) else float('inf') @@ -50,37 +53,43 @@ def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, curren max_epochs = min(trainer_max_epochs, (trainer_max_steps + max_batches - 1) // max_batches) max_steps = min(trainer_max_epochs * max_batches, trainer_max_steps) elapsed_sec = current_time - start_time + # get vars global_step = trainer.global_step + 1 epoch = trainer.current_epoch + 1 if hasattr(trainer, 'batch_idx'): batch = (trainer.batch_idx + 1) else: - 
# TODO: re-enable this warning but only ever print once! - # warnings.warn('batch_idx missing on pl.Trainer') + # warnings.warn('batch_idx missing on pl.Trainer') # TODO: re-enable this warning but only ever print once! batch = global_step % max_batches # might not be int? + # completion train_pct = global_step / max_steps train_remain_time = elapsed_sec * (1 - train_pct) / train_pct # seconds + # get speed -- TODO: make this a moving average? if global_step >= elapsed_sec: step_speed_str = f'{global_step / elapsed_sec:4.2f}it/s' else: step_speed_str = f'{elapsed_sec / global_step:4.2f}s/it' - # info dict + + # get the metrics and format them info_dict = { k: f'{v:.4g}' if isinstance(v, (int, float)) else f'{v}' - for k, v in trainer.progress_bar_dict.items() - if k != 'v_num' + for k, v in trainer.progress_bar_metrics.items() } + + # sort the keys placing loss entries first sorted_k = sorted(info_dict.keys(), key=lambda k: ('loss' != k.lower(), 'loss' not in k.lower(), k)) - # log - log.info( - f'[{int(elapsed_sec)}s, {step_speed_str}] ' - + f'EPOCH: {epoch}/{max_epochs} - {int(global_step):0{len(str(max_steps))}d}/{max_steps} ' - + f'({int(train_pct * 100):02d}%) [rem. {int(train_remain_time)}s] ' - + f'STEP: {int(batch):{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) ' - + f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}' + + # log everything + log.log( + level=self._log_level, + msg=f'[{int(elapsed_sec)}s, {step_speed_str}] ' + + f'EPOCH: {epoch}/{max_epochs} - {int(global_step):0{len(str(max_steps))}d}/{max_steps} ' + + f'({int(train_pct * 100):02d}%) [rem. 
{int(train_remain_time)}s] ' + + f'STEP: {int(batch):{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) ' + + f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}' ) diff --git a/disent/util/lightning/callbacks/_callback_vis_dists.py b/disent/util/lightning/callbacks/_callback_vis_dists.py new file mode 100644 index 00000000..aa33646f --- /dev/null +++ b/disent/util/lightning/callbacks/_callback_vis_dists.py @@ -0,0 +1,329 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import pytorch_lightning as pl +import torch + +# TODO: wandb and matplotlib are not in requirements +import matplotlib.pyplot as plt +import wandb + +import disent.util.strings.colors as c +from disent.dataset import DisentDataset +from disent.dataset.data import GroundTruthData +from disent.frameworks.ae import Ae +from disent.frameworks.vae import Vae +from disent.util.function import wrapped_partial +from disent.util.iters import chunked +from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic +from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like +from disent.util.lightning.logger_util import wb_log_metrics +from disent.util.profiling import Timer +from disent.util.seeds import TempNumpySeed +from disent.util.visualize.plot import plt_subplots_imshow + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Helper Functions # +# ========================================================================= # + + +# helper +def _to_dmat( + size: int, + i_a: np.ndarray, + i_b: np.ndarray, + dists: Union[torch.Tensor, np.ndarray], +) -> np.ndarray: + if isinstance(dists, torch.Tensor): + dists = dists.detach().cpu().numpy() + # checks + assert i_a.ndim == 1 + assert i_a.shape == i_b.shape + assert i_a.shape == dists.shape + # compute + dmat = np.zeros([size, size], dtype='float32') + dmat[i_a, i_b] = dists + dmat[i_b, i_a] = dists + return dmat + + +_AE_DIST_NAMES = ('x', 'z', 'x_recon') +_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon') + + +@torch.no_grad() +def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor): + # feed forward + z_a, z_b = ae.encode(x_a), 
ae.encode(x_b) + r_a, r_b = ae.decode(z_a), ae.decode(z_b) + # distances + return [ + ae.recon_handler.compute_pairwise_loss(x_a, x_b), + torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist + ae.recon_handler.compute_pairwise_loss(r_a, r_b), + ] + + +@torch.no_grad() +def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor): + from torch.distributions import kl_divergence + # feed forward + (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b) + z_a, z_b = z_post_a.mean, z_post_b.mean + r_a, r_b = vae.decode(z_a), vae.decode(z_b) + # dists + kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a) + # distances + return [ + vae.recon_handler.compute_pairwise_loss(x_a, x_b), + torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist + vae.recon_handler._pairwise_reduce(kl_ab), + vae.recon_handler.compute_pairwise_loss(r_a, r_b), + ] + + +def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]: + # get aggregate function + if isinstance(model, Vae): + dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model) + elif isinstance(model, Ae): + dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model) + else: + dists_names, dists_fn = None, None + return dists_names, dists_fn + + +@torch.no_grad() +def _collect_dists_subbatches(dists_fn: Callable[[object, object], Sequence[Sequence[float]]], batch: torch.Tensor, i_a: np.ndarray, i_b: np.ndarray, batch_size: int = 64): + # feed forward + results = [] + for idxs in chunked(np.stack([i_a, i_b], axis=-1), chunk_size=batch_size): + ia, ib = idxs.T + x_a, x_b = batch[ia], batch[ib] + # feed forward + data = dists_fn(x_a, x_b) + results.append(data) + return [torch.cat(r, dim=0) for r in zip(*results)] + + +def _compute_and_collect_dists( + dataset: DisentDataset, + dists_fn, + dists_names: Sequence[str], + traversal_repeats: int = 100, + batch_size: 
int = 32, + include_gt_factor_dists: bool = True, + transform_batch: Callable[[object], object] = None, + data_mode: str = 'input', +) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]: + assert traversal_repeats > 0 + gt_data = dataset.gt_data + # generate + f_grid = [] + # generate + for f_idx, f_size in enumerate(gt_data.factor_sizes): + # save for the current factor (traversal_repeats, len(names), len(i_a)) + f_dists = [] + # upper triangle excluding diagonal + i_a, i_b = np.triu_indices(f_size, k=1) + # repeat over random traversals + for i in range(traversal_repeats): + # get random factor traversal + factors = gt_data.sample_random_factor_traversal(f_idx=f_idx) + indices = gt_data.pos_to_idx(factors) + # load data + batch = dataset.dataset_batch_from_indices(indices, data_mode) + if transform_batch is not None: + batch = transform_batch(batch) + # feed forward & compute dists -- (len(names), len(i_a)) + dists = _collect_dists_subbatches(dists_fn=dists_fn, batch=batch, i_a=i_a, i_b=i_b, batch_size=batch_size) + assert len(dists) == len(dists_names) + # distances + f_dists.append(dists) + # aggregate all dists into distances matrices for current factor + f_dmats = [ + _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=torch.stack(dists, dim=0).mean(dim=0)) + for dists in zip(*f_dists) + ] + # handle factors + if include_gt_factor_dists: + i_dmat = _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=np.abs(factors[i_a] - factors[i_b]).sum(axis=-1)) + f_dmats = [i_dmat, *f_dmats] + # append data + f_grid.append(f_dmats) + # handle factors + if include_gt_factor_dists: + dists_names = ('factors', *dists_names) + # done + return tuple(dists_names), f_grid + + +def compute_factor_distances( + dataset: DisentDataset, + dists_fn, + dists_names: Sequence[str], + traversal_repeats: int = 100, + batch_size: int = 32, + include_gt_factor_dists: bool = True, + transform_batch: Callable[[object], object] = None, + seed: Optional[int] = 777, + data_mode: str = 'input', +) -> 
Tuple[Tuple[str, ...], List[List[np.ndarray]]]: + # log this callback + gt_data = dataset.gt_data + log.info(f'| {gt_data.name} - computing factor distances...') + # compute various distances matrices for each factor + with Timer() as timer, TempNumpySeed(seed): + dists_names, f_grid = _compute_and_collect_dists( + dataset=dataset, + dists_fn=dists_fn, + dists_names=dists_names, + traversal_repeats=traversal_repeats, + batch_size=batch_size, + include_gt_factor_dists=include_gt_factor_dists, + transform_batch=transform_batch, + data_mode=data_mode, + ) + # log this callback! + log.info(f'| {gt_data.name} - computed factor distances! time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST}') + return dists_names, f_grid + + +def plt_factor_distances( + gt_data: GroundTruthData, + f_grid: List[List[np.ndarray]], + dists_names: Sequence[str], + title: str, + plt_block_size: float = 1.25, + plt_transpose: bool = False, + plt_cmap='Blues', +): + # plot information + imshow_kwargs = dict(cmap=plt_cmap) + figsize = (plt_block_size*len(f_grid[0]), plt_block_size * gt_data.num_factors) + # plot! 
+ if not plt_transpose: + fig, axs = plt_subplots_imshow(grid=f_grid, col_labels=dists_names, row_labels=gt_data.factor_names, figsize=figsize, title=title, imshow_kwargs=imshow_kwargs) + else: + fig, axs = plt_subplots_imshow(grid=list(zip(*f_grid)), col_labels=gt_data.factor_names, row_labels=dists_names, figsize=figsize[::-1], title=title, imshow_kwargs=imshow_kwargs) + # done + return fig, axs + + +# ========================================================================= # +# Data Dists Visualisation Callback # +# ========================================================================= # + + +class VaeGtDistsLoggingCallback(BaseCallbackPeriodic): + + def __init__( + self, + seed: Optional[int] = 7777, + every_n_steps: Optional[int] = None, + traversal_repeats: int = 100, + begin_first_step: bool = False, + plt_block_size: float = 1.25, + plt_show: bool = False, + plt_transpose: bool = False, + log_wandb: bool = True, # TODO: detect this automatically? + batch_size: int = 128, + include_factor_dists: bool = True, + ): + assert traversal_repeats > 0 + self._traversal_repeats = traversal_repeats + self._seed = seed + self._plt_block_size = plt_block_size + self._plt_show = plt_show + self._log_wandb = log_wandb + self._include_gt_factor_dists = include_factor_dists + self._transpose_plot = plt_transpose + self._batch_size = batch_size + super().__init__(every_n_steps, begin_first_step) + + @torch.no_grad() + def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + # exit early + if not (self._plt_show or self._log_wandb): + log.warning(f'skipping {self.__class__.__name__} neither `plt_show` or `log_wandb` is `True`!') + return + # get dataset and vae framework from trainer and module + dataset, vae = _get_dataset_and_ae_like(trainer, pl_module, unwrap_groundtruth=True) + # exit early + if not dataset.is_ground_truth: + log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!') + return + # get aggregate function + 
dists_names, dists_fn = _get_dists_fn(vae) + if (dists_names is None) or (dists_fn is None): + log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}') + return + # compute various distances matrices for each factor + dists_names, f_grid = compute_factor_distances( + dataset=dataset, + dists_fn=dists_fn, + dists_names=dists_names, + traversal_repeats=self._traversal_repeats, + batch_size=self._batch_size, + include_gt_factor_dists=self._include_gt_factor_dists, + transform_batch=lambda batch: batch.to(vae.device), + seed=self._seed, + data_mode='input', + ) + # plot these results + fig, axs = plt_factor_distances( + gt_data=dataset.gt_data, + f_grid=f_grid, + dists_names=dists_names, + title=f'{vae.__class__.__name__}: {dataset.gt_data.name.capitalize()} Distances', + plt_block_size=self._plt_block_size, + plt_transpose=self._transpose_plot, + plt_cmap='Blues', + ) + # show the plot + if self._plt_show: + plt.show() + # log the plot to wandb + if self._log_wandb: + wb_log_metrics(trainer.logger, { + 'factor_distances': wandb.Image(fig) + }) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/lightning/callbacks/_callback_vis_latents.py b/disent/util/lightning/callbacks/_callback_vis_latents.py new file mode 100644 index 00000000..a33aae0c --- /dev/null +++ b/disent/util/lightning/callbacks/_callback_vis_latents.py @@ -0,0 +1,336 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import pytorch_lightning as pl +import torch + +from disent.dataset import DisentDataset +from disent.frameworks.ae import Ae +from disent.frameworks.vae import Vae +from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic +from disent.util.lightning.callbacks._helper import _get_dataset_and_ae_like +from disent.util.lightning.logger_util import wb_log_metrics +from disent.util.seeds import TempNumpySeed +from disent.util.visualize.vis_img import torch_to_images +from disent.util.visualize.vis_latents import make_decoded_latent_cycles +from disent.util.visualize.vis_util import make_animated_image_grid +from disent.util.visualize.vis_util import make_image_grid + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# Helper # +# ========================================================================= # + + +MinMaxHint = Optional[Union[int, Literal['auto']]] +MeanStdHint = 
Optional[Union[Tuple[float, ...], float]] + + +def get_vis_min_max( + recon_min: MinMaxHint = None, + recon_max: MinMaxHint = None, + recon_mean: MeanStdHint = None, + recon_std: MeanStdHint = None, +) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: + # check recon_min and recon_max + if (recon_min is not None) or (recon_max is not None): + if (recon_mean is not None) or (recon_std is not None): + raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both') + if (recon_min is None) or (recon_max is None): + raise ValueError('both recon_min & recon_max must be specified') + # check strings + if isinstance(recon_min, str) or isinstance(recon_max, str): + if not (isinstance(recon_min, str) and isinstance(recon_max, str)): + raise ValueError('both recon_min & recon_max must be "auto" if one is "auto"') + return None, None # "auto" -> None + # check recon_mean and recon_std + elif (recon_mean is not None) or (recon_std is not None): + if (recon_min is not None) or (recon_max is not None): + raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both') + if (recon_mean is None) or (recon_std is None): + raise ValueError('both recon_mean & recon_std must be specified') + # set values: + # | ORIG: [0, 1] + # | TRANSFORM: (x - mean) / std -> [(0-mean)/std, (1-mean)/std] + # | REVERT: (x - min) / (max - min) -> [0, 1] + # | min=(0-mean)/std, max=(1-mean)/std + recon_mean, recon_std = np.array(recon_mean, dtype='float32'), np.array(recon_std, dtype='float32') + recon_min = np.divide(0 - recon_mean, recon_std) + recon_max = np.divide(1 - recon_mean, recon_std) + # set defaults + if recon_min is None: recon_min = 0.0 + if recon_max is None: recon_max = 1.0 + # change type + recon_min = np.array(recon_min) + recon_max = np.array(recon_max) + assert recon_min.ndim in (0, 1) + assert recon_max.ndim in (0, 1) + # checks + assert np.all(recon_min < recon_max), 
f'recon_min={recon_min} must be less than recon_max={recon_max}' + return recon_min, recon_max + + +# ========================================================================= # +# Latent Visualisation Callback # +# ========================================================================= # + + +class VaeLatentCycleLoggingCallback(BaseCallbackPeriodic): + + def __init__( + self, + seed: Optional[int] = 7777, + every_n_steps: Optional[int] = None, + begin_first_step: bool = False, + num_frames: int = 17, + mode: str = 'minmax_interval_cycle', + num_stats_samples: int = 64, + log_wandb: bool = True, # TODO: detect this automatically? + wandb_mode: str = 'both', + wandb_fps: int = 4, + plt_show: bool = False, + plt_block_size: float = 1.0, + # recon_min & recon_max + recon_min: MinMaxHint = None, # scale data in this range [min, max] to [0, 1] + recon_max: MinMaxHint = None, # scale data in this range [min, max] to [0, 1] + recon_mean: MeanStdHint = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1] + recon_std: MeanStdHint = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1] + ): + super().__init__(every_n_steps, begin_first_step) + self._seed = seed + self._mode = mode + self._num_stats_samples = num_stats_samples + self._plt_show = plt_show + self._plt_block_size = plt_block_size + self._log_wandb = log_wandb + self._wandb_mode = wandb_mode + self._num_frames = num_frames + self._fps = wandb_fps + # checks + assert wandb_mode in {'img', 'vid', 'both'}, f'invalid wandb_mode={repr(wandb_mode)}, must be one of: ("img", "vid", "both")' + # normalize + self._recon_min, self._recon_max = get_vis_min_max( + recon_min=recon_min, + recon_max=recon_max, + recon_mean=recon_mean, + recon_std=recon_std, + ) + + @torch.no_grad() + def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + # exit early + if not (self._plt_show or 
self._log_wandb): + log.warning(f'skipping {self.__class__.__name__} neither `plt_show` or `log_wandb` is `True`!') + return + + # feed forward and visualise everything! + stills, animation, image = self.get_visualisations(trainer, pl_module) + + # log video -- none, img, vid, both + # TODO: there might be a memory leak in making the video below? Or there could be one in the actual DSPRITES dataset? memory usage seems to be very high and increase on this dataset. + if self._log_wandb: + import wandb + wandb_items = {} + if self._wandb_mode in ('img', 'both'): wandb_items[f'{self._mode}_img'] = wandb.Image(image) + if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self._mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4') + wb_log_metrics(trainer.logger, wandb_items) + + # log locally + if self._plt_show: + from matplotlib import pyplot as plt + fig, ax = plt.subplots(1, 1, figsize=(self._plt_block_size*stills.shape[1], self._plt_block_size*stills.shape[0])) + ax.imshow(image) + ax.axis('off') + fig.tight_layout() + plt.show() + + def get_visualisations( + self, + trainer_or_dataset: Union[pl.Trainer, DisentDataset], + pl_module: pl.LightningModule, + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]]: + return self.generate_visualisations( + trainer_or_dataset, + pl_module, + seed=self._seed, + num_frames=self._num_frames, + mode=self._mode, + num_stats_samples=self._num_stats_samples, + recon_min=self._recon_min, + recon_max=self._recon_max, + recon_mean=None, + recon_std=None, + ) + + @classmethod + def generate_visualisations( + cls, + trainer_or_dataset: Union[pl.Trainer, DisentDataset], + pl_module: pl.LightningModule, + seed: Optional[int] = 7777, + num_frames: int = 17, + mode: str = 'fitted_gaussian_cycle', + num_stats_samples: int = 64, + # recon_min & recon_max + recon_min: MinMaxHint = None, + recon_max: MinMaxHint = None, + 
recon_mean: MeanStdHint = None, + recon_std: MeanStdHint = None, + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]]: + # normalize + recon_min, recon_max = get_vis_min_max( + recon_min=recon_min, + recon_max=recon_max, + recon_mean=recon_mean, + recon_std=recon_std, + ) + + # get dataset and vae framework from trainer and module + dataset, vae = _get_dataset_and_ae_like(trainer_or_dataset, pl_module, unwrap_groundtruth=True) + + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # + # generate traversal + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # + + # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle + with TempNumpySeed(seed): + batch, indices = dataset.dataset_sample_batch(num_stats_samples, mode='input', replace=True, return_indices=True) # replace just in case the dataset is tiny + batch = batch.to(vae.device) + + # get representations + if isinstance(vae, Vae): + # variational auto-encoder + ds_posterior, ds_prior = vae.encode_dists(batch) + zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance) + elif isinstance(vae, Ae): + # auto-encoder + zs_mean = vae.encode(batch) + zs_logvar = torch.ones_like(zs_mean) + else: + log.warning(f'cannot run {cls.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}') + return + + # get min and max if auto + if (recon_min is None) or (recon_max is None): + if recon_min is None: recon_min = float(torch.amin(batch).cpu()) + if recon_max is None: recon_max = float(torch.amax(batch).cpu()) + log.info(f'auto visualisation min: {recon_min} and max: {recon_max} obtained from {len(batch)} samples') + + # produce latent cycle still images & convert them to images + stills = make_decoded_latent_cycles(vae.decode, zs_mean, zs_logvar, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=vae.device)[0] + stills = torch_to_images(stills, in_dims='CHW', out_dims='HWC', in_min=recon_min, 
in_max=recon_max, always_rgb=True, to_numpy=True) + + # generate the video frames and image grid from the stills + # - TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves + # - TODO: should this not use `visualize_dataset_traversal`? + frames = make_animated_image_grid(stills, pad=4, border=True, bg_color=None) + image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4, border=True, bg_color=None) + + # done + return stills, frames, image + + + +# TODO: FIX THIS ERROR: +# [2022-03-12 11:44:59,661][experiment.util.run_utils][ERROR] - exiting: experiment error | [Errno 12] Cannot allocate memory +# Traceback (most recent call last): +# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 462, in hydra_main +# run_action(cfg) +# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 378, in run_action +# action(cfg) +# File "/home-mscluster/nmichlo/workspace/research/disent/experiment/run.py", line 351, in action_train +# trainer.fit(framework, datamodule=datamodule) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit +# self._call_and_handle_interrupt( +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt +# return trainer_fn(*args, **kwargs) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl +# self._run(model, ckpt_path=ckpt_path) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run +# 
self._dispatch() +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch +# self.training_type_plugin.start_training(self) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training +# self._results = trainer.run_stage() +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage +# return self._run_train() +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1312, in _run_train +# self.fit_loop.run() +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run +# self.advance(*args, **kwargs) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance +# self.epoch_loop.run(data_fetcher) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run +# self.advance(*args, **kwargs) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 221, in advance +# self.trainer.call_hook("on_batch_end") +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1477, in call_hook +# 
callback_fx(*args, **kwargs) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 163, in on_batch_end +# callback.on_batch_end(self, self.lightning_module) +# File "/home-mscluster/nmichlo/workspace/research/disent/disent/util/lightning/callbacks/_callbacks_base.py", line 57, in on_batch_end +# self.do_step(trainer, pl_module) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context +# return func(*args, **kwargs) +# File "/home-mscluster/nmichlo/workspace/research/disent/disent/util/lightning/callbacks/_callback_vis_latents.py", line 165, in do_step +# if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self._mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4'), +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/wandb/sdk/data_types.py", line 1232, in __init__ +# self.encode() +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/wandb/sdk/data_types.py", line 1255, in encode +# clip.write_videofile(filename, **kwargs) +# File "", line 2, in write_videofile +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 54, in requires_duration +# return f(clip, *a, **k) +# File "", line 2, in write_videofile +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 135, in use_clip_fps_by_default +# return f(clip, *new_a, **new_kw) +# File "", line 2, in write_videofile +# File 
"/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/decorators.py", line 22, in convert_masks_to_RGB +# return f(clip, *a, **k) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/VideoClip.py", line 300, in write_videofile +# ffmpeg_write_video(self, filename, fps, codec, +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/io/ffmpeg_writer.py", line 213, in ffmpeg_write_video +# with FFMPEG_VideoWriter(filename, clip.size, fps, codec = codec, +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/site-packages/moviepy/video/io/ffmpeg_writer.py", line 129, in __init__ +# self.proc = sp.Popen(cmd, **popen_params) +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/subprocess.py", line 951, in __init__ +# self._execute_child(args, executable, preexec_fn, close_fds, +# File "/home-mscluster/nmichlo/installed/pyenv/versions/miniconda3-latest/envs/disent-conda/lib/python3.9/subprocess.py", line 1754, in _execute_child +# self.pid = _posixsubprocess.fork_exec( +# OSError: [Errno 12] Cannot allocate memory + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/lightning/callbacks/_callbacks_base.py b/disent/util/lightning/callbacks/_callbacks_base.py index db380ec1..3385b8f7 100644 --- a/disent/util/lightning/callbacks/_callbacks_base.py +++ b/disent/util/lightning/callbacks/_callbacks_base.py @@ -24,6 +24,8 @@ import logging import time +from typing import Optional + import pytorch_lightning as pl @@ -37,7 +39,8 @@ class BaseCallbackPeriodic(pl.Callback): - def __init__(self, 
every_n_steps=None, begin_first_step=False): + def __init__(self, every_n_steps: Optional[int] = None, begin_first_step: bool = False): + assert (every_n_steps is None) or (isinstance(every_n_steps, int) and every_n_steps > 0), f'`every_n_steps` must be None or an integer greater than zero, got: {repr(every_n_steps)}' self.every_n_steps = every_n_steps self.begin_first_step = begin_first_step @@ -59,7 +62,8 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): class BaseCallbackTimed(pl.Callback): - def __init__(self, interval=10): + def __init__(self, interval: float = 10): + assert interval > 0, f'The interval must be greater than zero, got: {repr(interval)}' self._last_time = 0 self._interval = interval self._start_time = time.time() diff --git a/disent/util/lightning/callbacks/_callbacks_vae.py b/disent/util/lightning/callbacks/_callbacks_vae.py deleted file mode 100644 index e846b86d..00000000 --- a/disent/util/lightning/callbacks/_callbacks_vae.py +++ /dev/null @@ -1,639 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import warnings -from typing import Callable -from typing import List -from typing import Literal -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union - -import numpy as np -import pytorch_lightning as pl -import torch - -from pytorch_lightning.trainer.supporters import CombinedLoader -from torch.utils.data.dataloader import default_collate -from tqdm import tqdm - -import disent.metrics -import disent.util.strings.colors as c -from disent.dataset import DisentDataset -from disent.dataset.data import GroundTruthData -from disent.frameworks.ae import Ae -from disent.frameworks.helper.reconstructions import make_reconstruction_loss -from disent.frameworks.helper.reconstructions import ReconLossHandler -from disent.frameworks.vae import Vae -from disent.util.function import wrapped_partial -from disent.util.iters import chunked -from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic -from disent.util.lightning.logger_util import log_metrics -from disent.util.lightning.logger_util import wb_log_metrics -from disent.util.lightning.logger_util import wb_log_reduced_summaries -from disent.util.profiling import Timer -from disent.util.seeds import TempNumpySeed -from disent.util.visualize.plot import plt_subplots_imshow -from disent.util.visualize.vis_model import latent_cycle_grid_animation -from disent.util.visualize.vis_util import make_image_grid - -# TODO: wandb and matplotlib are not in requirements -import matplotlib.pyplot as plt -import wandb - - -log = logging.getLogger(__name__) - - -# 
========================================================================= # -# Helper Functions # -# ========================================================================= # - - -def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule, unwrap_groundtruth: bool = False) -> (DisentDataset, Ae): - # TODO: improve handling! - assert isinstance(pl_module, Ae), f'{pl_module.__class__} is not an instance of {Ae}' - # get dataset - if hasattr(trainer, 'datamodule') and (trainer.datamodule is not None): - assert hasattr(trainer.datamodule, 'dataset_train_noaug') # TODO: this is for experiments, another way of handling this should be added - dataset = trainer.datamodule.dataset_train_noaug - elif hasattr(trainer, 'train_dataloader') and (trainer.train_dataloader is not None): - if isinstance(trainer.train_dataloader, CombinedLoader): - dataset = trainer.train_dataloader.loaders.dataset - else: - raise RuntimeError(f'invalid trainer.train_dataloader: {trainer.train_dataloader}') - else: - raise RuntimeError('could not retrieve dataset! please report this...') - # check dataset - assert isinstance(dataset, DisentDataset), f'retrieved dataset is not an {DisentDataset.__name__}' - # unwarp dataset - if unwrap_groundtruth: - if dataset.is_wrapped_gt_data: - old_dataset, dataset = dataset, dataset.unwrapped_disent_dataset() - warnings.warn(f'Unwrapped ground truth dataset returned! 
{type(old_dataset.data).__name__} -> {type(dataset.data).__name__}') - # done checks - return dataset, pl_module - - -# ========================================================================= # -# Vae Framework Callbacks # -# ========================================================================= # - -# helper -def _to_dmat( - size: int, - i_a: np.ndarray, - i_b: np.ndarray, - dists: Union[torch.Tensor, np.ndarray], -) -> np.ndarray: - if isinstance(dists, torch.Tensor): - dists = dists.detach().cpu().numpy() - # checks - assert i_a.ndim == 1 - assert i_a.shape == i_b.shape - assert i_a.shape == dists.shape - # compute - dmat = np.zeros([size, size], dtype='float32') - dmat[i_a, i_b] = dists - dmat[i_b, i_a] = dists - return dmat - - -_AE_DIST_NAMES = ('x', 'z', 'x_recon') -_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon') - - -@torch.no_grad() -def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor): - # feed forware - z_a, z_b = ae.encode(x_a), ae.encode(x_b) - r_a, r_b = ae.decode(z_a), ae.decode(z_b) - # distances - return [ - ae.recon_handler.compute_pairwise_loss(x_a, x_b), - torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist - ae.recon_handler.compute_pairwise_loss(r_a, r_b), - ] - - -@torch.no_grad() -def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor): - from torch.distributions import kl_divergence - # feed forward - (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b) - z_a, z_b = z_post_a.mean, z_post_b.mean - r_a, r_b = vae.decode(z_a), vae.decode(z_b) - # dists - kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a) - # distances - return [ - vae.recon_handler.compute_pairwise_loss(x_a, x_b), - torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist - vae.recon_handler._pairwise_reduce(kl_ab), - vae.recon_handler.compute_pairwise_loss(r_a, r_b), - ] - - -def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], 
Sequence[Sequence[float]]]]]: - # get aggregate function - if isinstance(model, Vae): - dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model) - elif isinstance(model, Ae): - dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model) - else: - dists_names, dists_fn = None, None - return dists_names, dists_fn - - -@torch.no_grad() -def _collect_dists_subbatches(dists_fn: Callable[[object, object], Sequence[Sequence[float]]], batch: torch.Tensor, i_a: np.ndarray, i_b: np.ndarray, batch_size: int = 64): - # feed forward - results = [] - for idxs in chunked(np.stack([i_a, i_b], axis=-1), chunk_size=batch_size): - ia, ib = idxs.T - x_a, x_b = batch[ia], batch[ib] - # feed forward - data = dists_fn(x_a, x_b) - results.append(data) - return [torch.cat(r, dim=0) for r in zip(*results)] - - -def _compute_and_collect_dists( - dataset: DisentDataset, - dists_fn, - dists_names: Sequence[str], - traversal_repeats: int = 100, - batch_size: int = 32, - include_gt_factor_dists: bool = True, - transform_batch: Callable[[object], object] = None, - data_mode: str = 'input', -) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]: - assert traversal_repeats > 0 - gt_data = dataset.gt_data - # generate - f_grid = [] - # generate - for f_idx, f_size in enumerate(gt_data.factor_sizes): - # save for the current factor (traversal_repeats, len(names), len(i_a)) - f_dists = [] - # upper triangle excluding diagonal - i_a, i_b = np.triu_indices(f_size, k=1) - # repeat over random traversals - for i in range(traversal_repeats): - # get random factor traversal - factors = gt_data.sample_random_factor_traversal(f_idx=f_idx) - indices = gt_data.pos_to_idx(factors) - # load data - batch = dataset.dataset_batch_from_indices(indices, data_mode) - if transform_batch is not None: - batch = transform_batch(batch) - # feed forward & compute dists -- (len(names), len(i_a)) - dists = _collect_dists_subbatches(dists_fn=dists_fn, batch=batch, i_a=i_a, i_b=i_b, 
batch_size=batch_size) - assert len(dists) == len(dists_names) - # distances - f_dists.append(dists) - # aggregate all dists into distances matrices for current factor - f_dmats = [ - _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=torch.stack(dists, dim=0).mean(dim=0)) - for dists in zip(*f_dists) - ] - # handle factors - if include_gt_factor_dists: - i_dmat = _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=np.abs(factors[i_a] - factors[i_b]).sum(axis=-1)) - f_dmats = [i_dmat, *f_dmats] - # append data - f_grid.append(f_dmats) - # handle factors - if include_gt_factor_dists: - dists_names = ('factors', *dists_names) - # done - return tuple(dists_names), f_grid - - -def compute_factor_distances( - dataset: DisentDataset, - dists_fn, - dists_names: Sequence[str], - traversal_repeats: int = 100, - batch_size: int = 32, - include_gt_factor_dists: bool = True, - transform_batch: Callable[[object], object] = None, - seed: Optional[int] = 777, - data_mode: str = 'input', -) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]: - # log this callback - gt_data = dataset.gt_data - log.info(f'| {gt_data.name} - computing factor distances...') - # compute various distances matrices for each factor - with Timer() as timer, TempNumpySeed(seed): - dists_names, f_grid = _compute_and_collect_dists( - dataset=dataset, - dists_fn=dists_fn, - dists_names=dists_names, - traversal_repeats=traversal_repeats, - batch_size=batch_size, - include_gt_factor_dists=include_gt_factor_dists, - transform_batch=transform_batch, - data_mode=data_mode, - ) - # log this callback! - log.info(f'| {gt_data.name} - computed factor distances! 
time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST}') - return dists_names, f_grid - - -def plt_factor_distances( - gt_data: GroundTruthData, - f_grid: List[List[np.ndarray]], - dists_names: Sequence[str], - title: str, - plt_block_size: float = 1.25, - plt_transpose: bool = False, - plt_cmap='Blues', -): - # plot information - imshow_kwargs = dict(cmap=plt_cmap) - figsize = (plt_block_size*len(f_grid[0]), plt_block_size * gt_data.num_factors) - # plot! - if not plt_transpose: - fig, axs = plt_subplots_imshow(grid=f_grid, col_labels=dists_names, row_labels=gt_data.factor_names, figsize=figsize, title=title, imshow_kwargs=imshow_kwargs) - else: - fig, axs = plt_subplots_imshow(grid=list(zip(*f_grid)), col_labels=gt_data.factor_names, row_labels=dists_names, figsize=figsize[::-1], title=title, imshow_kwargs=imshow_kwargs) - # done - return fig, axs - - -class VaeGtDistsLoggingCallback(BaseCallbackPeriodic): - - def __init__( - self, - seed: Optional[int] = 7777, - every_n_steps: Optional[int] = None, - traversal_repeats: int = 100, - begin_first_step: bool = False, - plt_block_size: float = 1.25, - plt_show: bool = False, - plt_transpose: bool = False, - log_wandb: bool = True, - batch_size: int = 128, - include_factor_dists: bool = True, - ): - assert traversal_repeats > 0 - self._traversal_repeats = traversal_repeats - self._seed = seed - self._plt_block_size = plt_block_size - self._plt_show = plt_show - self._log_wandb = log_wandb - self._include_gt_factor_dists = include_factor_dists - self._transpose_plot = plt_transpose - self._batch_size = batch_size - super().__init__(every_n_steps, begin_first_step) - - @torch.no_grad() - def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - # get dataset and vae framework from trainer and module - dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True) - # exit early - if not dataset.is_ground_truth: - log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, 
skipping!') - return - # get aggregate function - dists_names, dists_fn = _get_dists_fn(vae) - if (dists_names is None) or (dists_fn is None): - log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}') - return - # compute various distances matrices for each factor - dists_names, f_grid = compute_factor_distances( - dataset=dataset, - dists_fn=dists_fn, - dists_names=dists_names, - traversal_repeats=self._traversal_repeats, - batch_size=self._batch_size, - include_gt_factor_dists=self._include_gt_factor_dists, - transform_batch=lambda batch: batch.to(vae.device), - seed=self._seed, - data_mode='input', - ) - # plot these results - fig, axs = plt_factor_distances( - gt_data=dataset.gt_data, - f_grid=f_grid, - dists_names=dists_names, - title=f'{vae.__class__.__name__}: {dataset.gt_data.name.capitalize()} Distances', - plt_block_size=self._plt_block_size, - plt_transpose=self._transpose_plot, - plt_cmap='Blues', - ) - # show the plot - if self._plt_show: - plt.show() - # log the plot to wandb - if self._log_wandb: - wb_log_metrics(trainer.logger, { - 'factor_distances': wandb.Image(fig) - }) - - -def _normalize_min_max_mean_std_to_min_max(recon_min, recon_max, recon_mean, recon_std) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]: - # check recon_min and recon_max - if (recon_min is not None) or (recon_max is not None): - if (recon_mean is not None) or (recon_std is not None): - raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both') - if (recon_min is None) or (recon_max is None): - raise ValueError('both recon_min & recon_max must be specified') - # check strings - if isinstance(recon_min, str) or isinstance(recon_max, str): - if not (isinstance(recon_min, str) and isinstance(recon_max, str)): - raise ValueError('both recon_min & recon_max must be "auto" if one is "auto"') - return None, None - # check recon_mean and recon_std - 
elif (recon_mean is not None) or (recon_std is not None): - if (recon_min is not None) or (recon_max is not None): - raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both') - if (recon_mean is None) or (recon_std is None): - raise ValueError('both recon_mean & recon_std must be specified') - # set values: - # | ORIG: [0, 1] - # | TRANSFORM: (x - mean) / std -> [(0-mean)/std, (1-mean)/std] - # | REVERT: (x - min) / (max - min) -> [0, 1] - # | min=(0-mean)/std, max=(1-mean)/std - recon_mean, recon_std = np.array(recon_mean, dtype='float32'), np.array(recon_std, dtype='float32') - recon_min = np.divide(0 - recon_mean, recon_std) - recon_max = np.divide(1 - recon_mean, recon_std) - # set defaults - if recon_min is None: recon_min = 0.0 - if recon_max is None: recon_max = 0.0 - # change type - recon_min = np.array(recon_min) - recon_max = np.array(recon_max) - assert recon_min.ndim in (0, 1) - assert recon_max.ndim in (0, 1) - # checks - assert np.all(recon_min < np.all(recon_max)), f'recon_min={recon_min} must be less than recon_max={recon_max}' - return recon_min, recon_max - - -class VaeLatentCycleLoggingCallback(BaseCallbackPeriodic): - - def __init__( - self, - seed: Optional[int] = 7777, - every_n_steps: Optional[int] = None, - begin_first_step: bool = False, - num_frames: int = 17, - mode: str = 'fitted_gaussian_cycle', - wandb_mode: str = 'both', - wandb_fps: int = 4, - plt_show: bool = False, - plt_block_size: float = 1.0, - # recon_min & recon_max - recon_min: Optional[Union[int, Literal['auto']]] = None, # scale data in this range [min, max] to [0, 1] - recon_max: Optional[Union[int, Literal['auto']]] = None, # scale data in this range [min, max] to [0, 1] - recon_mean: Optional[Union[Tuple[float, ...], float]] = None, # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1] - recon_std: Optional[Union[Tuple[float, ...], float]] = None, # automatically 
converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1] - ): - super().__init__(every_n_steps, begin_first_step) - self.seed = seed - self.mode = mode - self.plt_show = plt_show - self.plt_block_size = plt_block_size - self._wandb_mode = wandb_mode - self._num_frames = num_frames - self._fps = wandb_fps - # checks - assert wandb_mode in {'none', 'img', 'vid', 'both'}, f'invalid wandb_mode={repr(wandb_mode)}, must be one of: ("none", "img", "vid", "both")' - # normalize - self._recon_min, self._recon_max = _normalize_min_max_mean_std_to_min_max( - recon_min=recon_min, - recon_max=recon_max, - recon_mean=recon_mean, - recon_std=recon_std, - ) - - - def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - # get dataset and vae framework from trainer and module - dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True) - - # TODO: should this not use `visualize_dataset_traversal`? - - with torch.no_grad(): - # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle - with TempNumpySeed(self.seed): - batch = dataset.dataset_sample_batch(64, mode='input').to(vae.device) - - # get representations - if isinstance(vae, Vae): - # variational auto-encoder - ds_posterior, ds_prior = vae.encode_dists(batch) - zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance) - elif isinstance(vae, Ae): - # auto-encoder - zs_mean = vae.encode(batch) - zs_logvar = torch.ones_like(zs_mean) - else: - log.warning(f'cannot run {self.__class__.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}') - return - - # get min and max if auto - if (self._recon_min is None) or (self._recon_max is None): - if self._recon_min is None: self._recon_min = float(torch.min(batch).cpu()) - if self._recon_max is None: self._recon_max = float(torch.max(batch).cpu()) - log.info(f'auto visualisation min: {self._recon_min} and max: {self._recon_max} 
obtained from {len(batch)} samples') - - # produce latent cycle grid animation - # TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves - animation, stills = latent_cycle_grid_animation( - vae.decode, zs_mean, zs_logvar, - mode=self.mode, num_frames=self._num_frames, decoder_device=vae.device, tensor_style_channels=False, return_stills=True, - to_uint8=True, recon_min=self._recon_min, recon_max=self._recon_max, - ) - image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4) - - # log video -- none, img, vid, both - wandb_items = {} - if self._wandb_mode in ('img', 'both'): wandb_items[f'{self.mode}_img'] = wandb.Image(image) - if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self.mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4'), - wb_log_metrics(trainer.logger, wandb_items) - - # log locally - if self.plt_show: - fig, ax = plt.subplots(1, 1, figsize=(self.plt_block_size*stills.shape[1], self.plt_block_size*stills.shape[0])) - ax.imshow(image) - ax.axis('off') - fig.tight_layout() - plt.show() - - -def _normalized_numeric_metrics(items: dict): - results = {} - for k, v in items.items(): - if isinstance(v, (float, int)): - results[k] = v - else: - try: - results[k] = float(v) - except: - log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}') - return results - - -class VaeMetricLoggingCallback(BaseCallbackPeriodic): - - def __init__( - self, - step_end_metrics: Optional[Sequence[str]] = None, - train_end_metrics: Optional[Sequence[str]] = None, - every_n_steps: Optional[int] = None, - begin_first_step: bool = False, - ): - super().__init__(every_n_steps, begin_first_step) - self.step_end_metrics = step_end_metrics if step_end_metrics else [] - self.train_end_metrics = train_end_metrics if train_end_metrics else [] - assert isinstance(self.step_end_metrics, list) - assert 
isinstance(self.train_end_metrics, list) - assert self.step_end_metrics or self.train_end_metrics, 'No metrics given to step_end_metrics or train_end_metrics' - - def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningModule, metrics: list, is_final=False): - # get dataset and vae framework from trainer and module - dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True) - # check if we need to skip - # TODO: dataset needs to be able to handle wrapped datasets! - if not dataset.is_ground_truth: - warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!') - return - # compute all metrics - for metric in metrics: - pad = max(7+len(k) for k in disent.metrics.DEFAULT_METRICS) # I know this is a magic variable... im just OCD - if is_final: - log.info(f'| {metric.__name__:<{pad}} - computing...') - with Timer() as timer: - scores = metric(dataset, lambda x: vae.encode(x.to(vae.device))) - metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items()) - log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}') - - # log to trainer - prefix = 'final_metric' if is_final else 'epoch_metric' - prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()} - log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores)) - - # log summary for WANDB - # this is kinda hacky... the above should work for parallel coordinate plots - wb_log_reduced_summaries(trainer.logger, prefixed_scores, reduction='max') - - def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - if self.step_end_metrics: - log.debug('Computing Epoch Metrics:') - with Timer() as timer: - self._compute_metrics_and_log(trainer, pl_module, metrics=self.step_end_metrics, is_final=False) - log.debug(f'Computed Epoch Metrics! 
{timer.pretty}') - - def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - if self.train_end_metrics: - log.debug('Computing Final Metrics...') - with Timer() as timer: - self._compute_metrics_and_log(trainer, pl_module, metrics=self.train_end_metrics, is_final=True) - log.debug(f'Computed Final Metrics! {timer.pretty}') - - -# class VaeLatentCorrelationLoggingCallback(BaseCallbackPeriodic): -# -# def __init__(self, repeats_per_factor=10, every_n_steps=None, begin_first_step=False): -# super().__init__(every_n_steps=every_n_steps, begin_first_step=begin_first_step) -# self._repeats_per_factor = repeats_per_factor -# -# def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): -# # get dataset and vae framework from trainer and module -# dataset, vae = _get_dataset_and_vae(trainer, pl_module) -# # check if we need to skip -# if not dataset.is_ground_truth: -# warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!') -# return -# # TODO: CONVERT THIS TO A METRIC! -# # log the correspondence between factors and the latent space. 
-# num_samples = np.sum(dataset.ground_truth_data.factor_sizes) * self._repeats_per_factor -# factors = dataset.ground_truth_data.sample_factors(num_samples) -# # encode observations of factors -# zs = np.concatenate([ -# to_numpy(vae.encode(dataset.dataset_batch_from_factors(factor_batch, mode='input').to(vae.device))) -# for factor_batch in iter_chunks(factors, 256) -# ]) -# z_size = zs.shape[-1] -# -# # calculate correlation matrix -# f_and_z = np.concatenate([factors.T, zs.T]) -# f_and_z_corr = np.corrcoef(f_and_z) -# # get correlation submatricies -# f_corr = f_and_z_corr[:z_size, :z_size] # upper left -# z_corr = f_and_z_corr[z_size:, z_size:] # bottom right -# fz_corr = f_and_z_corr[z_size:, :z_size] # upper right | y: z, x: f -# # get maximum z correlations per factor -# z_to_f_corr_maxs = np.max(np.abs(fz_corr), axis=0) -# f_to_z_corr_maxs = np.max(np.abs(fz_corr), axis=1) -# assert len(z_to_f_corr_maxs) == z_size -# assert len(f_to_z_corr_maxs) == dataset.ground_truth_data.num_factors -# # average correlation -# ave_f_to_z_corr = f_to_z_corr_maxs.mean() -# ave_z_to_f_corr = z_to_f_corr_maxs.mean() -# -# # print -# log.info(f'ave latent correlation: {ave_z_to_f_corr}') -# log.info(f'ave factor correlation: {ave_f_to_z_corr}') -# # log everything -# log_metrics(trainer.logger, { -# 'metric.ave_latent_correlation': ave_z_to_f_corr, -# 'metric.ave_factor_correlation': ave_f_to_z_corr, -# }) -# # make sure we only log the heatmap to WandB -# wb_log_metrics(trainer.logger, { -# 'metric.correlation_heatmap': wandb.plots.HeatMap( -# x_labels=[f'z{i}' for i in range(z_size)], -# y_labels=list(dataset.ground_truth_data.factor_names), -# matrix_values=fz_corr, show_text=False -# ), -# }) -# -# NUM = 1 -# # generate traversal value graphs -# for i in range(z_size): -# correlation = np.abs(f_corr[i, :]) -# correlation[i] = 0 -# for j in np.argsort(correlation)[::-1][:NUM]: -# if i == j: -# continue -# ix, iy = (i, j) # if i < j else (j, i) -# plt.scatter(zs[:, ix], 
zs[:, iy]) -# plt.title(f'z{ix}-vs-z{iy}') -# plt.xlabel(f'z{ix}') -# plt.ylabel(f'z{iy}') -# -# # wandb.log({f"chart.correlation.z{ix}-vs-z{iy}": plt}) -# # make sure we only log to WANDB -# wb_log_metrics(trainer.logger, {f"chart.correlation.z{ix}-vs-max-corr": plt}) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/util/lightning/callbacks/_helper.py b/disent/util/lightning/callbacks/_helper.py new file mode 100644 index 00000000..6d9883af --- /dev/null +++ b/disent/util/lightning/callbacks/_helper.py @@ -0,0 +1,152 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import warnings +from typing import Union + +import pytorch_lightning as pl +from pytorch_lightning.trainer.supporters import CombinedLoader + +from disent.dataset import DisentDataset +from disent.frameworks.ae import Ae +from disent.frameworks.vae import Vae + + +# ========================================================================= # +# Helper Functions # +# ========================================================================= # + + +def _get_dataset_and_ae_like(trainer_or_dataset: pl.Trainer, pl_module: pl.LightningModule, unwrap_groundtruth: bool = False) -> (DisentDataset, Union[Ae, Vae]): + assert isinstance(pl_module, (Ae, Vae)), f'{pl_module.__class__} is not an instance of {Ae} or {Vae}' + # get dataset + if isinstance(trainer_or_dataset, pl.Trainer): + trainer = trainer_or_dataset + if hasattr(trainer, 'datamodule') and (trainer.datamodule is not None): + assert hasattr(trainer.datamodule, 'dataset_train_noaug') # TODO: this is for experiments, another way of handling this should be added + dataset = trainer.datamodule.dataset_train_noaug + elif hasattr(trainer, 'train_dataloader') and (trainer.train_dataloader is not None): + if isinstance(trainer.train_dataloader, CombinedLoader): + dataset = trainer.train_dataloader.loaders.dataset + else: + raise RuntimeError(f'invalid trainer.train_dataloader: {trainer.train_dataloader}') + else: + raise RuntimeError('could not retrieve dataset! please report this...') + else: + dataset = trainer_or_dataset + # check dataset + assert isinstance(dataset, DisentDataset), f'retrieved dataset is not an {DisentDataset.__name__}' + # unwrap dataset + if unwrap_groundtruth: + if dataset.is_wrapped_gt_data: + old_dataset, dataset = dataset, dataset.unwrapped_shallow_copy() + warnings.warn(f'Unwrapped ground truth dataset returned! 
{type(old_dataset.data).__name__} -> {type(dataset.data).__name__}') + # done checks + return dataset, pl_module + + +# ========================================================================= # +# END # +# ========================================================================= # + + +# class VaeLatentCorrelationLoggingCallback(BaseCallbackPeriodic): +# +# def __init__(self, repeats_per_factor=10, every_n_steps=None, begin_first_step=False): +# super().__init__(every_n_steps=every_n_steps, begin_first_step=begin_first_step) +# self._repeats_per_factor = repeats_per_factor +# +# def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): +# # get dataset and vae framework from trainer and module +# dataset, vae = _get_dataset_and_vae(trainer, pl_module) +# # check if we need to skip +# if not dataset.is_ground_truth: +# warnings.warn(f'{dataset.__class__.__name__} is not an instance of {GroundTruthData.__name__}. Skipping callback: {self.__class__.__name__}!') +# return +# # TODO: CONVERT THIS TO A METRIC! +# # log the correspondence between factors and the latent space. 
+# num_samples = np.sum(dataset.ground_truth_data.factor_sizes) * self._repeats_per_factor +# factors = dataset.ground_truth_data.sample_factors(num_samples) +# # encode observations of factors +# zs = np.concatenate([ +# to_numpy(vae.encode(dataset.dataset_batch_from_factors(factor_batch, mode='input').to(vae.device))) +# for factor_batch in iter_chunks(factors, 256) +# ]) +# z_size = zs.shape[-1] +# +# # calculate correlation matrix +# f_and_z = np.concatenate([factors.T, zs.T]) +# f_and_z_corr = np.corrcoef(f_and_z) +# # get correlation submatricies +# f_corr = f_and_z_corr[:z_size, :z_size] # upper left +# z_corr = f_and_z_corr[z_size:, z_size:] # bottom right +# fz_corr = f_and_z_corr[z_size:, :z_size] # upper right | y: z, x: f +# # get maximum z correlations per factor +# z_to_f_corr_maxs = np.max(np.abs(fz_corr), axis=0) +# f_to_z_corr_maxs = np.max(np.abs(fz_corr), axis=1) +# assert len(z_to_f_corr_maxs) == z_size +# assert len(f_to_z_corr_maxs) == dataset.ground_truth_data.num_factors +# # average correlation +# ave_f_to_z_corr = f_to_z_corr_maxs.mean() +# ave_z_to_f_corr = z_to_f_corr_maxs.mean() +# +# # print +# log.info(f'ave latent correlation: {ave_z_to_f_corr}') +# log.info(f'ave factor correlation: {ave_f_to_z_corr}') +# # log everything +# log_metrics(trainer.logger, { +# 'metric.ave_latent_correlation': ave_z_to_f_corr, +# 'metric.ave_factor_correlation': ave_f_to_z_corr, +# }) +# # make sure we only log the heatmap to WandB +# wb_log_metrics(trainer.logger, { +# 'metric.correlation_heatmap': wandb.plots.HeatMap( +# x_labels=[f'z{i}' for i in range(z_size)], +# y_labels=list(dataset.ground_truth_data.factor_names), +# matrix_values=fz_corr, show_text=False +# ), +# }) +# +# NUM = 1 +# # generate traversal value graphs +# for i in range(z_size): +# correlation = np.abs(f_corr[i, :]) +# correlation[i] = 0 +# for j in np.argsort(correlation)[::-1][:NUM]: +# if i == j: +# continue +# ix, iy = (i, j) # if i < j else (j, i) +# plt.scatter(zs[:, ix], 
zs[:, iy]) +# plt.title(f'z{ix}-vs-z{iy}') +# plt.xlabel(f'z{ix}') +# plt.ylabel(f'z{iy}') +# +# # wandb.log({f"chart.correlation.z{ix}-vs-z{iy}": plt}) +# # make sure we only log to WANDB +# wb_log_metrics(trainer.logger, {f"chart.correlation.z{ix}-vs-max-corr": plt}) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/lightning/logger_util.py b/disent/util/lightning/logger_util.py index 0daec37e..c99fcce1 100644 --- a/disent/util/lightning/logger_util.py +++ b/disent/util/lightning/logger_util.py @@ -34,6 +34,11 @@ log = logging.getLogger(__name__) +# TODO: convert these functions into generalised log_img, log_vid, +# log_number, log_text functions that support the different backends. +# - wandb +# - disk backend +# - comet ml # ========================================================================= # # Logger Utils # diff --git a/disent/util/math/dither.py b/disent/util/math/dither.py index 4ae2d524..20229cb1 100644 --- a/disent/util/math/dither.py +++ b/disent/util/math/dither.py @@ -176,6 +176,7 @@ def _is_power_2(num: int): @functools.lru_cache() def _normalize_axis(ndim: int, axis: Optional[Sequence[int]]) -> np.ndarray: # TODO: this functionality may be duplicated + # -- similar to np.normalize_axis_tuple(...) 
# defaults if axis is None: axis = np.arange(ndim) diff --git a/disent/util/math/integer.py b/disent/util/math/integer.py new file mode 100644 index 00000000..eee92707 --- /dev/null +++ b/disent/util/math/integer.py @@ -0,0 +1,53 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +# ========================================================================= # +# Working With Arbitrary Precision Integers # +# ========================================================================= # + + +def gcd(a: int, b: int) -> int: + """ + Compute the greatest common divisor of a and b + TODO: not actually sure if this returns the correct values for zero or negative inputs? 
+ """ + assert isinstance(a, int), f'number must be an int, got: {type(a)}' + assert isinstance(b, int), f'number must be an int, got: {type(b)}' + while b > 0: + a, b = b, a % b + return a + + +def lcm(a: int, b: int) -> int: + """ + Compute the lowest common multiple of a and b + TODO: not actually sure if this returns the correct values for zero or negative inputs? + """ + return (a * b) // gcd(a, b) + + +# ========================================================================= # +# End # +# ========================================================================= # diff --git a/disent/util/visualize/plot.py b/disent/util/visualize/plot.py index 826594c2..6c9d3de4 100644 --- a/disent/util/visualize/plot.py +++ b/disent/util/visualize/plot.py @@ -44,6 +44,20 @@ log = logging.getLogger(__name__) +# ========================================================================= # +# vars # +# ========================================================================= # + + +_TORCH_NORMAL_TYPES = {torch.float16, torch.float32, torch.float64} + +# torch.complex32 exists in 1.10, but was disabled in 1.11, and planned to be added in 1.12 again +if torch.version.__version__.startswith('1.11.'): + _TORCH_COMPLEX_TYPES = {torch.complex64} +else: + _TORCH_COMPLEX_TYPES = {torch.complex32, torch.complex64} + + # ========================================================================= # # images # # ========================================================================= # @@ -59,11 +73,11 @@ def to_img(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> tor def to_imgs(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> torch.Tensor: # (..., C, H, W) assert x.ndim >= 3, 'image must have 3 or more dimensions: (..., C, H, W)' - assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64}, f'unsupported dtype: {x.dtype}' + assert (x.dtype in _TORCH_NORMAL_TYPES) or (x.dtype in _TORCH_COMPLEX_TYPES), f'unsupported 
dtype: {x.dtype}' # no gradient with torch.no_grad(): # imaginary to real - if x.dtype in {torch.complex32, torch.complex64}: + if x.dtype in _TORCH_COMPLEX_TYPES: x = torch.abs(x) # scale images if scale: @@ -310,11 +324,10 @@ def visualize_dataset_traversal( # TODO: this is kinda hacky, maybe rather add a check? # TODO: can this be moved into the `output_wandb` if statement? # - animations glitch out if they do not have 3 channels + assert grid.ndim == 5, f'invalid number of dimensions, must be 5, got: {grid.ndim}' if grid.shape[-1] == 1: grid = grid.repeat(3, axis=-1) - - assert grid.ndim == 5 - assert grid.shape[-1] in (1, 3) + assert grid.shape[-1] in (1, 3), f'invalid number of channels, must be 1 or 3, got shape: {grid.shape}. Note that the dataset or augment if specified should output HWC images, not CHW images!' # generate visuals image = make_image_grid(np.concatenate(grid, axis=0), pad=pad, border=border, bg_color=bg_color, num_cols=num_frames) diff --git a/disent/util/visualize/vis_img.py b/disent/util/visualize/vis_img.py new file mode 100644 index 00000000..874bd0d3 --- /dev/null +++ b/disent/util/visualize/vis_img.py @@ -0,0 +1,419 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +import warnings +from functools import lru_cache +from numbers import Number +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from numpy.core.multiarray import normalize_axis_index +from PIL import Image + + +# ========================================================================= # +# Type Hints # +# ========================================================================= # + + +# from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict +_NP_TO_TORCH_DTYPE = { + np.dtype('bool'): torch.bool, + np.dtype('uint8'): torch.uint8, + np.dtype('int8'): torch.int8, + np.dtype('int16'): torch.int16, + np.dtype('int32'): torch.int32, + np.dtype('int64'): torch.int64, + np.dtype('float16'): torch.float16, + np.dtype('float32'): torch.float32, + np.dtype('float64'): torch.float64, + np.dtype('complex64'): torch.complex64, + np.dtype('complex128'): torch.complex128 +} + + +MinMaxHint = Union[Number, Tuple[Number, ...], np.ndarray] + + +@lru_cache() +def _dtype_min_max(dtype: torch.dtype) -> Tuple[Union[float, int], Union[float, int]]: + """Get the min and max values for a dtype""" + dinfo = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype) + return dinfo.min, dinfo.max + + +@lru_cache() +def _check_image_dtype(dtype: torch.dtype): + """Check that a dtype can hold image values""" + # 
check that the datatype is within the right range -- this is not actually necessary if the below is correct! + dmin, dmax = _dtype_min_max(dtype) + imin, imax = (0, 1) if dtype.is_floating_point else (0, 255) + assert (dmin <= imin) and (imax <= dmax), f'The dtype: {repr(dtype)} with range [{dmin}, {dmax}] cannot store image values in the range [{imin}, {imax}]' + # check the datatype is allowed + if dtype not in _ALLOWED_DTYPES: + raise TypeError(f'The dtype: {repr(dtype)} is not allowed, must be one of: {list(_ALLOWED_DTYPES)}') + # return the min and max values + return imin, imax + + +# ========================================================================= # +# Image Helper Functions # +# ========================================================================= # + + +def torch_image_has_valid_range(tensor: torch.Tensor, check_mode: Optional[str] = None) -> bool: + """ + Check that the range of values in the image is correct! + """ + if check_mode not in {'error', 'warn', 'bool', None}: + raise KeyError(f'invalid check_mode: {repr(check_mode)}') + # get the range for the dtype + imin, imax = _check_image_dtype(tensor.dtype) + # get the values + m = tensor.amin().cpu().numpy() + M = tensor.amax().cpu().numpy() + if (m < imin) or (imax < M): + if check_mode == 'error': + raise ValueError(f'images value range: [{m}, {M}] is outside of the required range: [{imin}, {imax}], for dtype: {repr(tensor.dtype)}') + elif check_mode == 'warn': + warnings.warn(f'images value range: [{m}, {M}] is outside of the required range: [{imin}, {imax}], for dtype: {repr(tensor.dtype)}') + return False + return True + + +@torch.no_grad() +def torch_image_clamp(tensor: torch.Tensor, clamp_mode: str = 'warn') -> torch.Tensor: + """ + Clamp the image based on the dtype + Valid `clamp_mode`s are {'warn', 'error', 'clamp'} + """ + # check range of values + if clamp_mode in ('warn', 'error'): + torch_image_has_valid_range(tensor, check_mode=clamp_mode) + elif clamp_mode != 'clamp': + 
raise KeyError(f'invalid clamp mode: {repr(clamp_mode)}') + # get the range for the dtype + imin, imax = _check_image_dtype(tensor.dtype) + # clamp! + return torch.clamp(tensor, imin, imax) + + +@torch.no_grad() +def torch_image_to_dtype(tensor: torch.Tensor, out_dtype: torch.dtype): + """ + Convert an image to the specified dtype + - Scaling is automatically performed based on the input and output dtype + Floats should be in the range [0, 1], integers should be in the range [0, 255] + - if precision will be lost (), then the values are clamped! + """ + _check_image_dtype(tensor.dtype) + _check_image_dtype(out_dtype) + # check scale + torch_image_has_valid_range(tensor, check_mode='error') + # convert + if tensor.dtype.is_floating_point and (not out_dtype.is_floating_point): + # [float -> int] -- cast after scaling + return torch.clamp(tensor * 255, 0, 255).to(out_dtype) + elif (not tensor.dtype.is_floating_point) and out_dtype.is_floating_point: + # [int -> float] -- cast before scaling + return torch.clamp(tensor.to(out_dtype) / 255, 0, 1) + else: + # [int -> int] | [float -> float] + return tensor.to(out_dtype) + + +@torch.no_grad() +def torch_image_normalize_channels( + tensor: torch.Tensor, + in_min: MinMaxHint, + in_max: MinMaxHint, + channel_dim: int = -1, + out_dtype: Optional[torch.dtype] = None +): + if out_dtype is None: + out_dtype = tensor.dtype + # check dtypes + _check_image_dtype(out_dtype) + assert out_dtype.is_floating_point, f'out_dtype must be a floating point, got: {repr(out_dtype)}' + # get norm values padded to the dimension of the channel + in_min, in_max = _torch_channel_broadcast_scale_values(in_min, in_max, in_dtype=tensor.dtype, dim=channel_dim, ndim=tensor.ndim) + # convert + tensor = tensor.to(out_dtype) + in_min = torch.as_tensor(in_min, dtype=tensor.dtype, device=tensor.device) + in_max = torch.as_tensor(in_max, dtype=tensor.dtype, device=tensor.device) + # warn if the values are the same + if torch.any(in_min == in_max): + m = 
in_min.cpu().detach().numpy() + M = in_max.cpu().detach().numpy() + warnings.warn(f'minimum: {m} and maximum: {M} values are the same, scaling values to zero.') + # handle equal values + divisor = in_max - in_min + divisor[divisor == 0] = 1 + # normalize + return (tensor - in_min) / divisor + + +# ========================================================================= # +# Argument Helper # +# ========================================================================= # + + +# float16 doesn't always work, rather convert to float32 first +_ALLOWED_DTYPES = { + torch.float32, torch.float64, + torch.uint8, + torch.int, torch.int16, torch.int32, torch.int64, + torch.long, +} + + +@lru_cache() +def _torch_to_images_normalise_args(in_tensor_shape: Tuple[int, ...], in_tensor_dtype: torch.dtype, in_dims: str, out_dims: str, in_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype]): + # check types + if not isinstance(in_dims, str): raise TypeError(f'in_dims must be of type: {str}, but got: {type(in_dims)}') + if not isinstance(out_dims, str): raise TypeError(f'out_dims must be of type: {str}, but got: {type(out_dims)}') + # normalise dim names + in_dims = in_dims.upper() + out_dims = out_dims.upper() + # check dim values + if sorted(in_dims) != sorted('CHW'): raise KeyError(f'in_dims contains the symbols: {repr(in_dims)}, must contain only permutations of: {repr("CHW")}') + if sorted(out_dims) != sorted('CHW'): raise KeyError(f'out_dims contains the symbols: {repr(out_dims)}, must contain only permutations of: {repr("CHW")}') + # get dimension indices + in_c_dim = in_dims.index('C') - len(in_dims) + out_c_dim = out_dims.index('C') - len(out_dims) + transpose_indices = tuple(in_dims.index(c) - len(in_dims) for c in out_dims) + # check image tensor + if len(in_tensor_shape) < 3: + raise ValueError(f'images must have 3 or more dimensions corresponding to: (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}') + if in_tensor_shape[in_c_dim] not in (1, 3): + 
raise ValueError(f'images do not have the correct number of channels for dim "C", required: 1 or 3. Input format is (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}') + # get default values + if in_dtype is None: in_dtype = in_tensor_dtype + if out_dtype is None: out_dtype = in_dtype + # check dtypes allowed + if in_dtype not in _ALLOWED_DTYPES: raise TypeError(f'in_dtype is not allowed, got: {repr(in_dtype)} must be one of: {list(_ALLOWED_DTYPES)}') + if out_dtype not in _ALLOWED_DTYPES: raise TypeError(f'out_dtype is not allowed, got: {repr(out_dtype)} must be one of: {list(_ALLOWED_DTYPES)}') + # done! + return transpose_indices, in_dtype, out_dtype, out_c_dim + + +def _torch_channel_broadcast_scale_values( + in_min: MinMaxHint, + in_max: MinMaxHint, + in_dtype: torch.dtype, + dim: int, + ndim: int, +) -> Tuple[List[Number], List[Number]]: + return __torch_channel_broadcast_scale_values( + in_min=tuple(np.array(in_min).reshape(-1).tolist()), # TODO: this is slow? + in_max=tuple(np.array(in_max).reshape(-1).tolist()), # TODO: this is slow? 
+ in_dtype=in_dtype, + dim=dim, + ndim=ndim, + ) + +@lru_cache() +@torch.no_grad() +def __torch_channel_broadcast_scale_values( + in_min: MinMaxHint, + in_max: MinMaxHint, + in_dtype: torch.dtype, + dim: int, + ndim: int, +) -> Tuple[List[Number], List[Number]]: + # get the default values + in_min: np.ndarray = np.array((0.0 if in_dtype.is_floating_point else 0.0) if (in_min is None) else in_min) + in_max: np.ndarray = np.array((1.0 if in_dtype.is_floating_point else 255.0) if (in_max is None) else in_max) + # add missing axes + if in_min.ndim == 0: in_min = in_min[None] + if in_max.ndim == 0: in_max = in_max[None] + # checks + assert in_min.ndim == 1 + assert in_max.ndim == 1 + assert np.all(in_min <= in_max), f'min values are not <= the max values: {in_min} !<= {in_max}' + # normalize dim + dim = normalize_axis_index(dim, ndim=ndim) + # pad dim + r_pad = ndim - (dim + 1) + if r_pad > 0: + in_min = in_min[(...,) + ((None,)*r_pad)] + in_max = in_max[(...,) + ((None,)*r_pad)] + # done! + return in_min.tolist(), in_max.tolist() + + +# ========================================================================= # +# Image Conversion # +# ========================================================================= # + + +@torch.no_grad() +def torch_to_images( + tensor: torch.Tensor, + in_dims: str = 'CHW', # we always treat numpy by default as HWC, and torch.Tensor as CHW + out_dims: str = 'HWC', + in_dtype: Optional[torch.dtype] = None, + out_dtype: Optional[torch.dtype] = torch.uint8, + clamp_mode: str = 'warn', # clamp, warn, error + always_rgb: bool = False, + in_min: Optional[MinMaxHint] = None, + in_max: Optional[MinMaxHint] = None, + to_numpy: bool = False, +) -> Union[torch.Tensor, np.ndarray]: + """ + Convert a batch of image-like tensors to images. + A batch in this case consists of an arbitrary number of dimensions of a tensor, + with the last 3 dimensions making up the actual images. + + Process: + 1. check input dtype + 2. move axis + 3. normalize + 4. 
clamp values + 5. auto scale and convert + 6. convert to rgb + 7. check output dtype + + example: + Convert a tensor of non-normalised images (..., C, H, W) to a + tensor of normalised and clipped images (..., H, W, C). + - integer dtypes are expected to be in the range [0, 255] + - float dtypes are expected to be in the range [0, 1] + + # TODO: add support for uneven in/out dims, eg. in_dims="HW", out_dims="HWC" + """ + # 0.a. check tensor + if not isinstance(tensor, torch.Tensor): + raise TypeError(f'images must be of type: {torch.Tensor}, got: {type(tensor)}') + # 0.b. get arguments + transpose_indices, in_dtype, out_dtype, out_c_dim = _torch_to_images_normalise_args( + in_tensor_shape=tuple(tensor.shape), + in_tensor_dtype=tensor.dtype, + in_dims=in_dims, + out_dims=out_dims, + in_dtype=in_dtype, + out_dtype=out_dtype, + ) + # 1. check input dtype + if in_dtype != tensor.dtype: + raise TypeError(f'images dtype: {repr(tensor.dtype)} does not match in_dtype: {repr(in_dtype)}') + # 2. move axes + tensor = tensor.permute(*(i-tensor.ndim for i in range(tensor.ndim-3)), *transpose_indices) + # 3. normalise + if (in_min is not None) or (in_max is not None): + norm_dtype = (out_dtype if out_dtype.is_floating_point else torch.float32) + tensor = torch_image_normalize_channels(tensor, in_min=in_min, in_max=in_max, channel_dim=out_c_dim, out_dtype=norm_dtype) + # 4. clamp + tensor = torch_image_clamp(tensor, clamp_mode=clamp_mode) + # 5. auto scale and convert + tensor = torch_image_to_dtype(tensor, out_dtype=out_dtype) + # 6. convert to rgb + if always_rgb: + if tensor.shape[out_c_dim] == 1: + tensor = torch.repeat_interleave(tensor, 3, dim=out_c_dim) # torch.repeat is like np.tile, torch.repeat_interleave is like np.repeat + # 7. 
check output dtype + if out_dtype != tensor.dtype: + raise RuntimeError(f'[THIS IS A BUG!]: After conversion, images tensor dtype: {repr(tensor.dtype)} does not match out_dtype: {repr(out_dtype)}') + if torch.any(torch.isnan(tensor)): + raise RuntimeError('[THIS IS A BUG!]: After conversion, images contain NaN values!') + # convert to numpy + if to_numpy: + return tensor.detach().cpu().numpy() + return tensor + + +def numpy_to_images( + ndarray: np.ndarray, + in_dims: str = 'HWC', # we always treat numpy by default as HWC, and torch.Tensor as CHW + out_dims: str = 'HWC', + in_dtype: Optional[Union[str, np.dtype]] = None, + out_dtype: Optional[Union[str, np.dtype]] = np.dtype('uint8'), + clamp_mode: str = 'warn', # clamp, warn, error + always_rgb: bool = False, + in_min: Optional[MinMaxHint] = None, + in_max: Optional[MinMaxHint] = None, +) -> np.ndarray: + """ + Convert a batch of image-like arrays to images. + A batch in this case consists of an arbitrary number of dimensions of an array, + with the last 3 dimensions making up the actual images. + - See the docs for: torch_to_images(...) + """ + # convert numpy dtypes to torch + if in_dtype is not None: in_dtype = _NP_TO_TORCH_DTYPE[np.dtype(in_dtype)] + if out_dtype is not None: out_dtype = _NP_TO_TORCH_DTYPE[np.dtype(out_dtype)] + # convert back + array = torch_to_images( + tensor=torch.from_numpy(ndarray), + in_dims=in_dims, + out_dims=out_dims, + in_dtype=in_dtype, + out_dtype=out_dtype, + clamp_mode=clamp_mode, + always_rgb=always_rgb, + in_min=in_min, + in_max=in_max, + to_numpy=True, + ) + # done! 
return array + + +def numpy_to_pil_images( + ndarray: np.ndarray, + in_dims: str = 'HWC', # we always treat numpy by default as HWC, and torch.Tensor as CHW + clamp_mode: str = 'warn', + always_rgb: bool = False, + in_min: Optional[MinMaxHint] = None, + in_max: Optional[MinMaxHint] = None, +) -> np.ndarray: + """ + Convert a numpy array containing images, laid out as given by `in_dims` (default: (..., H, W, C)), to an array of PIL images (...,) + """ + imgs = numpy_to_images( + ndarray=ndarray, + in_dims=in_dims, + out_dims='HWC', + in_dtype=None, + out_dtype='uint8', + clamp_mode=clamp_mode, + always_rgb=always_rgb, + in_min=in_min, + in_max=in_max, + ) + # all the cases (even ndim == 3)... bravo numpy, bravo! + images = [Image.fromarray(imgs[idx]) for idx in np.ndindex(imgs.shape[:-3])] + images = np.array(images, dtype=object).reshape(imgs.shape[:-3]) + # done + return images + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/visualize/vis_model.py b/disent/util/visualize/vis_latents.py similarity index 69% rename from disent/util/visualize/vis_model.py rename to disent/util/visualize/vis_latents.py index b7735ce9..72b88d25 100644 --- a/disent/util/visualize/vis_model.py +++ b/disent/util/visualize/vis_latents.py @@ -23,13 +23,13 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Callable + import numpy as np import torch from disent.util import to_numpy from disent.util.visualize import vis_util -from disent.util.visualize.vis_util import make_animated_image_grid -from disent.util.visualize.vis_util import reconstructions_to_images log = logging.getLogger(__name__) @@ -46,6 +46,7 @@ # CHANGES: # # - extracted from original code # # - was not split into functions in this was # +# - TODO: convert these functions to torch.Tensors # # 
========================================================================= # @@ -98,56 +99,56 @@ def _z_minmax_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames): } +def make_latent_zs_cycle( + base_z: torch.Tensor, + z_means: torch.Tensor, + z_logvars: torch.Tensor, + z_idx: int, + num_frames: int, + mode: str = 'minmax_interval_cycle', +) -> torch.Tensor: + # get mode + if mode not in _LATENT_CYCLE_MODES_MAP: + raise KeyError(f'Unsupported mode: {repr(mode)} not in {set(_LATENT_CYCLE_MODES_MAP)}') + z_gen_func = _LATENT_CYCLE_MODES_MAP[mode] + # checks + assert base_z.ndim == 1 + assert base_z.shape == z_means.shape[1:] + assert z_means.ndim == z_logvars.ndim == 2 + assert z_means.shape == z_logvars.shape + assert len(z_means) > 1, f'not enough representations to average, number of z_means should be greater than 1, got: {z_means.shape}' + # make cycle + z_cycle = z_gen_func(to_numpy(base_z), to_numpy(z_means), to_numpy(z_logvars), z_idx, num_frames) + return torch.from_numpy(z_cycle) + + # ========================================================================= # # Visualise Latent Cycles # # ========================================================================= # -# TODO: this function should not convert output to images, it should just be -# left as is. That way we don't need to pass in the recon_min and recon_max -def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_animations=4, num_frames=20, decoder_device=None, recon_min=0., recon_max=1.) 
-> np.ndarray: - assert len(z_means) > 1 and len(z_logvars) > 1, 'not enough samples to average' - # convert - z_means, z_logvars = to_numpy(z_means), to_numpy(z_logvars) - # get mode - if mode not in _LATENT_CYCLE_MODES_MAP: - raise KeyError(f'Unsupported mode: {repr(mode)} not in {set(_LATENT_CYCLE_MODES_MAP)}') - z_gen_func = _LATENT_CYCLE_MODES_MAP[mode] +# TODO: this should be moved into the VAE and AE classes +def make_decoded_latent_cycles( + decoder_func: Callable[[torch.Tensor], torch.Tensor], + z_means: torch.Tensor, + z_logvars: torch.Tensor, + mode: str = 'minmax_interval_cycle', + num_animations: int = 4, + num_frames: int = 20, + decoder_device=None, +) -> torch.Tensor: + # generate multiple latent traversal visualisations animations = [] - for i, base_z in enumerate(z_means[:num_animations]): + for i in range(num_animations): frames = [] - for j in range(z_means.shape[1]): - z = z_gen_func(base_z, z_means, z_logvars, j, num_frames) + for z_idx in range(z_means.shape[1]): + z = make_latent_zs_cycle(z_means[i], z_means, z_logvars, z_idx, num_frames, mode=mode) z = torch.as_tensor(z, device=decoder_device) frames.append(decoder_func(z)) - animations.append(frames) - return reconstructions_to_images(animations, recon_min=recon_min, recon_max=recon_max) - - -def latent_cycle_grid_animation(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_frames=21, pad=4, border=True, bg_color=0.5, decoder_device=None, tensor_style_channels=True, always_rgb=True, return_stills=False, to_uint8=False, recon_min=0., recon_max=1.) 
-> np.ndarray: - # produce latent cycle animation & merge frames - stills = latent_cycle(decoder_func, z_means, z_logvars, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=decoder_device, recon_min=recon_min, recon_max=recon_max)[0] - # check and add missing channel if needed (convert greyscale to rgb images) - if always_rgb: - assert stills.shape[-1] in {1, 3}, f'Invalid number of image channels: {stills.shape} ({stills.shape[-1]})' - if stills.shape[-1] == 1: - stills = np.repeat(stills, 3, axis=-1) - # create animation - frames = make_animated_image_grid(stills, pad=pad, border=border, bg_color=bg_color) - # move channels to end - if tensor_style_channels: - if return_stills: - stills = np.transpose(stills, [0, 1, 4, 2, 3]) - frames = np.transpose(frames, [0, 3, 1, 2]) - # convert to uint8 - if to_uint8: - if return_stills: - stills = np.clip(stills*255, 0, 255).astype('uint8') - frames = np.clip(frames*255, 0, 255).astype('uint8') - # done! - if return_stills: - return frames, stills - return frames + animations.append(torch.stack(frames, dim=0)) + animations = torch.stack(animations, dim=0) + # return everything + return animations # (num_animations, z_size, num_frames, C, H, W) # ========================================================================= # diff --git a/disent/util/visualize/vis_util.py b/disent/util/visualize/vis_util.py index 7d46abe2..0385481a 100644 --- a/disent/util/visualize/vis_util.py +++ b/disent/util/visualize/vis_util.py @@ -24,7 +24,11 @@ import logging import warnings +from functools import lru_cache from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Union import numpy as np @@ -52,7 +56,7 @@ } -def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None): +def make_image_grid(images: Sequence[np.ndarray], pad: int = 8, border: bool = True, bg_color=None, num_cols: Optional[int] = None): """ Convert a list of images into a 
single image that is a grid of those images. :param images: list of input images, all the same size: (I, H, W, C) or (I, H, W) @@ -65,12 +69,12 @@ def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None): # first, second, third channels are the (H, W, C) # get image sizes img_shape, ndim = np.array(images[0].shape), images[0].ndim - assert ndim == 2 or ndim == 3, 'images have wrong number of channels' + assert ndim == 2 or ndim == 3, f'images have wrong number of channels: {img_shape}' assert np.all(img_shape == img.shape for img in images), 'Images are not the same shape!' # get image size and channels img_size = img_shape[:2] if ndim == 3: - assert (img_shape[2] == 1) or (img_shape[2] == 3), 'Invalid number of channels for an image.' + assert (img_shape[2] == 1) or (img_shape[2] == 3), f'Invalid number of channels for an image: {img_shape}' # get bg color if bg_color is None: bg_color = _BG_COLOR_DTYPE_MAP[images[0].dtype] @@ -90,7 +94,7 @@ def make_image_grid(images, pad=8, border=True, bg_color=None, num_cols=None): return grid -def make_animated_image_grid(list_of_animated_images, pad=8, border=True, bg_color=None, num_cols=None): +def make_animated_image_grid(list_of_animated_images: Sequence[np.ndarray], pad: int = 8, border: bool = True, bg_color=None, num_cols: Optional[int] = None): """ :param list_of_animated_images: list of input images, with the second dimension the number of frames: : (I, F, H, W, C) or (I, F, H, W) :param pad: the number of pixels between images @@ -114,7 +118,7 @@ def make_animated_image_grid(list_of_animated_images, pad=8, border=True, bg_col # ========================================================================= # -def _get_grid_size(n, num_cols=None): +def _get_grid_size(n: int, num_cols: Optional[int] = None): """ Determine the number of rows and columns, given the total number of elements n. 
- if num_cols is None: rows x cols is as square as possible @@ -135,7 +139,7 @@ def _get_grid_size(n, num_cols=None): # ========================================================================= # -def _get_interval_factor_traversal(factor_size, num_frames, start_index=0): +def _get_interval_factor_traversal(factor_size: int, num_frames: int, start_index: int = 0): """ Cycles through the state space in a single cycle. eg. num_indices=5, num_frames=7 returns: [0,1,1,2,3,3,4] @@ -147,29 +151,51 @@ def _get_interval_factor_traversal(factor_size, num_frames, start_index=0): return grid -def _get_cycle_factor_traversal(factor_size, num_frames): +def _get_cycle_factor_traversal(factor_size: int, num_frames: int, start_index: int = 0): """ Cycles through the state space in a single cycle. eg. num_indices=5, num_frames=7 returns: [0,1,3,4,3,2,1] eg. num_indices=4, num_frames=7 returns: [0,1,2,3,2,2,0] """ - grid = _get_interval_factor_traversal(factor_size=factor_size, num_frames=num_frames) + assert start_index == 0, f'cycle factor traversal mode only supports start_index==0, got: {repr(start_index)}' + grid = _get_interval_factor_traversal(factor_size=factor_size, num_frames=num_frames, start_index=0) grid = np.concatenate([grid[0::2], grid[1::2][::-1]]) return grid +def _get_cycle_factor_traversal_from_start(factor_size: int, num_frames: int, start_index: int = 0, ends: bool = False): + all_idxs = np.array([ + *range(start_index, factor_size - (1 if ends else 0)), + *reversed(range(0, factor_size)), + *range(1 if ends else 0, start_index), + ]) + selected_idxs = _get_interval_factor_traversal(factor_size=len(all_idxs), num_frames=num_frames, start_index=0) + grid = all_idxs[selected_idxs] + # checks + assert all_idxs[0] == start_index, 'Please report this bug!' + assert grid[0] == start_index, 'Please report this bug!' 
+ return grid + + +def _get_cycle_factor_traversal_from_start_ends(factor_size: int, num_frames: int, start_index: int = 0, ends: bool = True): + return _get_cycle_factor_traversal_from_start(factor_size=factor_size, num_frames=num_frames, start_index=start_index, ends=ends) + + + _FACTOR_TRAVERSALS = { 'interval': _get_interval_factor_traversal, 'cycle': _get_cycle_factor_traversal, + 'cycle_from_start': _get_cycle_factor_traversal_from_start, + 'cycle_from_start_ends': _get_cycle_factor_traversal_from_start_ends, } -def get_idx_traversal(factor_size, num_frames, mode='interval'): +def get_idx_traversal(factor_size: int, num_frames: int, mode: str = 'interval', start_index: int = 0): try: traversal_fn = _FACTOR_TRAVERSALS[mode] except KeyError: raise KeyError(f'Invalid factor traversal mode: {repr(mode)}') - return traversal_fn(factor_size=factor_size, num_frames=num_frames) + return traversal_fn(factor_size=factor_size, num_frames=num_frames, start_index=start_index) # ========================================================================= # @@ -213,295 +239,5 @@ def cycle_interval(starting_value, num_frames, min_val, max_val): # ========================================================================= # -# Conversion/Util # -# ========================================================================= # - - -# TODO: this functionality is duplicated elsewhere! -# TODO: similar functions exist: output_image, to_img, to_imgs, reconstructions_to_images -def reconstructions_to_images( - recon, - mode: str = 'float', - moveaxis: bool = True, - recon_min: Union[float, List[float]] = 0.0, - recon_max: Union[float, List[float]] = 1.0, - warn_if_clipped: bool = True, -) -> Union[np.ndarray, Image.Image]: - """ - Convert a batch of reconstructions to images. - A batch in this case consists of an arbitrary number of dimensions of an array, - with the last 3 dimensions making up the actual image. 
For example: (..., channels, size, size) - - NOTE: This function might not be efficient for large amounts of - data due to assertions and initial recursive conversions to a numpy array. - - NOTE: kornia has a similar function! - """ - img = to_numpy(recon) - # checks - assert img.ndim >= 3 - assert img.dtype in (np.float32, np.float64) - # move channels axis - if moveaxis: - img = np.moveaxis(img, -3, -1) - # check min and max - recon_min = np.array(recon_min) - recon_max = np.array(recon_max) - assert recon_min.shape == recon_max.shape - assert recon_min.ndim in (0, 1) # supports channels or glbal min . max - # scale image - img = (img - recon_min) / (recon_max - recon_min) - # check image bounds - if warn_if_clipped: - m, M = np.min(img), np.max(img) - if m < 0 or M > 1: - log.warning(f'images with range [{m}, {M}] have been clipped to the range [0, 1]') - # do clipping - img = np.clip(img, 0, 1) - # convert - if mode == 'float': - return img - elif mode == 'int': - return np.uint8(img * 255) - elif mode == 'pil': - img = np.uint8(img * 255) - # all the cases (even ndim == 3)... bravo numpy, bravo! - images = [Image.fromarray(img[idx]) for idx in np.ndindex(img.shape[:-3])] - images = np.array(images, dtype=object).reshape(img.shape[:-3]) - return images - else: - raise KeyError(f'Invalid mode: {repr(mode)} not in { {"float", "int", "pil"} }') - - -# ========================================================================= # -# Image Util # -# ========================================================================= # - - -# TODO: something like this should replace reconstructions_to_images above! 
-# def torch_image_clamp(tensor: torch.Tensor, clamp_mode='warn') -> torch.Tensor: -# # get dtype max value -# dtype_M = 1 if tensor.dtype.is_floating_point else 255 -# # handle different modes -# if clamp_mode in ('warn', 'error'): -# m, M = tensor.min().cpu().numpy(), tensor.max().cpu().numpy() -# if (0 < m) or (M > dtype_M): -# if clamp_mode == 'warn': -# warnings.warn(f'image values are out of bounds, expected values in the range: [0, {dtype_M}], received values in the range: {[m, M]}') -# else: -# raise ValueError(f'image values are out of bounds, expected values in the range: [0, {dtype_M}], received values in the range: {[m, M]}') -# elif clamp_mode != 'clamp': -# raise KeyError(f'invalid clamp mode: {repr(clamp_mode)}') -# # clamp values -# return torch.clamp(tensor, 0, dtype_M) -# -# -# # float16 doesnt always work, rather convert to float32 first -# _ALLOWED_DTYPES = { -# torch.float32, torch.float64, -# torch.uint8, -# torch.int, torch.int16, torch.int32, torch.int64, -# torch.long, -# } -# -# -# @lru_cache() -# def _torch_to_images_normalise_args(in_tensor_shape: Tuple[int, ...], in_tensor_dtype: torch.dtype, in_dims: str, out_dims: str, in_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype]): -# # check types -# if not isinstance(in_dims, str): raise TypeError(f'in_dims must be of type: {str}, but got: {type(in_dims)}') -# if not isinstance(out_dims, str): raise TypeError(f'out_dims must be of type: {str}, but got: {type(out_dims)}') -# # normalise dim names -# in_dims = in_dims.upper() -# out_dims = out_dims.upper() -# # check dim values -# if sorted(in_dims) != sorted('CHW'): raise KeyError(f'in_dims contains the symbols: {repr(in_dims)}, must contain only permutations of: {repr("CHW")}') -# if sorted(out_dims) != sorted('CHW'): raise KeyError(f'out_dims contains the symbols: {repr(out_dims)}, must contain only permutations of: {repr("CHW")}') -# # get dimension indices -# in_c_dim = in_dims.index('C') - len(in_dims) -# transpose_indices = 
tuple(in_dims.index(c) - len(in_dims) for c in out_dims) -# # check image tensor -# if len(in_tensor_shape) < 3: -# raise ValueError(f'images must have 3 or more dimensions corresponding to: (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}') -# if in_tensor_shape[in_c_dim] not in (1, 3): -# raise ValueError(f'images do not have the correct number of channels for dim "C", required: 1 or 3. Input format is (..., {", ".join(in_dims)}), but got shape: {in_tensor_shape}') -# # get default values -# if in_dtype is None: in_dtype = in_tensor_dtype -# if out_dtype is None: out_dtype = in_dtype -# # check dtypes allowed -# if in_dtype not in _ALLOWED_DTYPES: raise TypeError(f'in_dtype is not allowed, got: {repr(in_dtype)} must be one of: {list(_ALLOWED_DTYPES)}') -# if out_dtype not in _ALLOWED_DTYPES: raise TypeError(f'out_dtype is not allowed, got: {repr(out_dtype)} must be one of: {list(_ALLOWED_DTYPES)}') -# # done! -# return transpose_indices, in_dtype, out_dtype -# -# -# def torch_to_images( -# tensor: torch.Tensor, -# in_dims: str = 'CHW', -# out_dims: str = 'HWC', -# in_dtype: Optional[torch.dtype] = None, -# out_dtype: Optional[torch.dtype] = torch.uint8, -# clamp_mode: str = 'warn', # clamp, warn, error -# ) -> torch.Tensor: -# """ -# Convert a batch of image-like tensors to images. -# A batch in this case consists of an arbitrary number of dimensions of a tensor, -# with the last 3 dimensions making up the actual images. -# -# example: -# Convert a tensor of non-normalised images (..., C, H, W) to a -# tensor of normalised and clipped images (..., H, W, C). 
-# - integer dtypes are expected to be in the range [0, 255] -# - float dtypes are expected to be in the range [0, 1] -# """ -# if not isinstance(tensor, torch.Tensor): -# raise TypeError(f'images must be of type: {torch.Tensor}, got: {type(tensor)}') -# # check arguments -# transpose_indices, in_dtype, out_dtype = _torch_to_images_normalise_args( -# in_tensor_shape=tuple(tensor.shape), in_tensor_dtype=tensor.dtype, -# in_dims=in_dims, out_dims=out_dims, -# in_dtype=in_dtype, out_dtype=out_dtype, -# ) -# # check inputs -# if in_dtype != tensor.dtype: -# raise TypeError(f'images dtype: {repr(tensor.dtype)} does not match in_dtype: {repr(in_dtype)}') -# # convert images -# with torch.no_grad(): -# # check that input values are in the correct range -# # move axes -# tensor = tensor.permute(*(i-tensor.ndim for i in range(tensor.ndim-3)), *transpose_indices) -# # convert outputs -# if in_dtype != out_dtype: -# if in_dtype.is_floating_point and (not out_dtype.is_floating_point): -# tensor = (tensor * 255).to(out_dtype) -# elif (not in_dtype.is_floating_point) and out_dtype.is_floating_point: -# tensor = tensor.to(out_dtype) / 255 -# else: -# tensor = tensor.to(out_dtype) -# # clamp -# tensor = torch_image_clamp(tensor, clamp_mode=clamp_mode) -# # check outputs -# if out_dtype != tensor.dtype: # pragma: no cover -# raise RuntimeError(f'[THIS IS A BUG! PLEASE REPORT THIS!]: After conversion, images tensor dtype: {repr(tensor.dtype)} does not match out_dtype: {repr(in_dtype)}') -# # done -# return tensor -# -# -# def numpy_to_images( -# ndarray: np.ndarray, -# in_dims: str = 'CHW', -# out_dims: str = 'HWC', -# in_dtype: Optional[str, np.dtype] = None, -# out_dtype: Optional[str, np.dtype] = np.dtype('uint8'), -# clamp_mode: str = 'warn', # clamp, warn, error -# ) -> np.ndarray: -# """ -# Convert a batch of image-like arrays to images. -# A batch in this case consists of an arbitrary number of dimensions of an array, -# with the last 3 dimensions making up the actual images. 
-# - See the docs for: torch_to_imgs(...) -# """ -# # convert numpy dtypes to torch -# if in_dtype is not None: in_dtype = getattr(torch, np.dtype(in_dtype).name) -# if out_dtype is not None: out_dtype = getattr(torch, np.dtype(out_dtype).name) -# # convert back -# tensor = torch_to_images(tensor=torch.from_numpy(ndarray), in_dims=in_dims, out_dims=out_dims, in_dtype=in_dtype, out_dtype=out_dtype, clamp_mode=clamp_mode) -# # done! -# return tensor.numpy() -# -# -# def numpy_to_pil_images(ndarray: np.ndarray, in_dims: str = 'CHW', clamp_mode: str = 'warn'): -# """ -# Convert a numpy array containing images (..., C, H, W) to an array of PIL images (...,) -# """ -# imgs = numpy_to_images(ndarray=ndarray, in_dims=in_dims, out_dims='HWC', in_dtype=None, out_dtype='uint8', clamp_mode=clamp_mode) -# # all the cases (even ndim == 3)... bravo numpy, bravo! -# images = [Image.fromarray(imgs[idx]) for idx in np.ndindex(imgs.shape[:-3])] -# images = np.array(images, dtype=object).reshape(imgs.shape[:-3]) -# # done -# return images -# -# -# def test_torch_to_imgs(): -# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) -# inp_uint8 = (inp_float * 127 + 63).to(torch.uint8) -# # check runs -# out = torch_to_imgs(inp_float) -# assert out.dtype == torch.uint8 -# out = torch_to_imgs(inp_uint8) -# assert out.dtype == torch.uint8 -# out = torch_to_imgs(inp_float, in_dtype=None, out_dtype=None) -# assert out.dtype == inp_float.dtype -# out = torch_to_imgs(inp_uint8, in_dtype=None, out_dtype=None) -# assert out.dtype == inp_uint8.dtype -# -# -# def test_torch_to_imgs_permutations(): -# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) -# inp_uint8 = (inp_float * 127 + 63).to(torch.uint8) -# -# # general checks -# def check_all(inputs, in_dtype=None): -# float_results, int_results = [], [] -# for out_dtype in _ALLOWED_DTYPES: -# out = torch_to_imgs(inputs, in_dtype=in_dtype, out_dtype=out_dtype) -# (float_results if out_dtype.is_floating_point else 
int_results).append(torch.stack([ -# out.min().to(torch.float64), out.max().to(torch.float64), out.mean(dtype=torch.float64) -# ])) -# for a, b in zip(float_results[:-1], float_results[1:]): assert torch.allclose(a, b) -# for a, b in zip(int_results[:-1], int_results[1:]): assert torch.allclose(a, b) -# -# # check type permutations -# check_all(inp_float, torch.float32) -# check_all(inp_uint8, torch.uint8) -# -# -# def test_torch_to_imgs_preserve_type(): -# for dtype in _ALLOWED_DTYPES: -# tensor = (torch.rand(8, 3, 64, 64) * (1 if dtype.is_floating_point else 255)).to(dtype) -# out = torch_to_imgs(tensor, in_dtype=dtype, out_dtype=dtype, clamp=True) -# assert out.dtype == dtype -# -# -# def test_torch_to_imgs_args(): -# inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) -# -# # check tensor -# with pytest.raises(TypeError, match="images tensor must be of type"): -# torch_to_imgs(tensor=None) -# with pytest.raises(ValueError, match='dim "C", required: 1 or 3'): -# torch_to_imgs(tensor=torch.rand(8, 2, 16, 16, dtype=torch.float32)) -# with pytest.raises(ValueError, match='dim "C", required: 1 or 3'): -# torch_to_imgs(tensor=torch.rand(8, 16, 16, 3, dtype=torch.float32)) -# with pytest.raises(ValueError, match='images tensor must have 3 or more dimensions corresponding to'): -# torch_to_imgs(tensor=torch.rand(16, 16, dtype=torch.float32)) -# -# # check dims -# with pytest.raises(TypeError, match="in_dims must be of type"): -# torch_to_imgs(inp_float, in_dims=None) -# with pytest.raises(TypeError, match="out_dims must be of type"): -# torch_to_imgs(inp_float, out_dims=None) -# with pytest.raises(KeyError, match="in_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"): -# torch_to_imgs(inp_float, in_dims='INVALID') -# with pytest.raises(KeyError, match="out_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"): -# torch_to_imgs(inp_float, out_dims='INVALID') -# with pytest.raises(KeyError, match="in_dims 
contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"): -# torch_to_imgs(inp_float, in_dims='CHWW') -# with pytest.raises(KeyError, match="out_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"): -# torch_to_imgs(inp_float, out_dims='CHWW') -# -# # check dtypes -# with pytest.raises(TypeError, match="images tensor dtype: torch.float32 does not match in_dtype: torch.uint8"): -# torch_to_imgs(inp_float, in_dtype=torch.uint8) -# with pytest.raises(TypeError, match='in_dtype is not allowed'): -# torch_to_imgs(inp_float, in_dtype=torch.complex64) -# with pytest.raises(TypeError, match='out_dtype is not allowed'): -# torch_to_imgs(inp_float, out_dtype=torch.complex64) -# with pytest.raises(TypeError, match='in_dtype is not allowed'): -# torch_to_imgs(inp_float, in_dtype=torch.float16) -# with pytest.raises(TypeError, match='out_dtype is not allowed'): -# torch_to_imgs(inp_float, out_dtype=torch.float16) - - -# ========================================================================= # -# END # +# END # # ========================================================================= # diff --git a/docs/img/traversals/traversal-transpose__cars3d.jpg b/docs/img/traversals/traversal-transpose__cars3d.jpg new file mode 100644 index 00000000..80bac07e Binary files /dev/null and b/docs/img/traversals/traversal-transpose__cars3d.jpg differ diff --git a/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg b/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg new file mode 100644 index 00000000..7db78492 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites-imagenet-bg-100.jpg differ diff --git a/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg b/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg new file mode 100644 index 00000000..1dae74d1 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites-imagenet-fg-100.jpg 
differ diff --git a/docs/img/traversals/traversal-transpose__dsprites.jpg b/docs/img/traversals/traversal-transpose__dsprites.jpg new file mode 100644 index 00000000..58a292b7 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__dsprites.jpg differ diff --git a/docs/img/traversals/traversal-transpose__shapes3d.jpg b/docs/img/traversals/traversal-transpose__shapes3d.jpg new file mode 100644 index 00000000..b8cc2a2f Binary files /dev/null and b/docs/img/traversals/traversal-transpose__shapes3d.jpg differ diff --git a/docs/img/traversals/traversal-transpose__smallnorb.jpg b/docs/img/traversals/traversal-transpose__smallnorb.jpg new file mode 100644 index 00000000..80043344 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__smallnorb.jpg differ diff --git a/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg b/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg new file mode 100644 index 00000000..4204acf6 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__xy-object-shaded.jpg differ diff --git a/docs/img/traversals/traversal-transpose__xy-object.jpg b/docs/img/traversals/traversal-transpose__xy-object.jpg new file mode 100644 index 00000000..c50ae6d6 Binary files /dev/null and b/docs/img/traversals/traversal-transpose__xy-object.jpg differ diff --git a/docs/img/xy-object-traversal.png b/docs/img/xy-object-traversal.png deleted file mode 100644 index a26524ac..00000000 Binary files a/docs/img/xy-object-traversal.png and /dev/null differ diff --git a/experiment/config/augment/basic.yaml b/experiment/config/augment/basic.yaml deleted file mode 100644 index 2f44ce41..00000000 --- a/experiment/config/augment/basic.yaml +++ /dev/null @@ -1,43 +0,0 @@ -name: basic - -augment_cls: - _target_: torchvision.transforms.RandomOrder - transforms: - - _target_: kornia.augmentation.ColorJitter - p: 0.5 - brightness: 0.25 - contrast: 0.25 - saturation: 0 - hue: 0.15 - -# THIS IS BUGGY ON BATCH -# - _target_: 
kornia.augmentation.RandomSharpness -# p: 0.5 -# sharpness: 0.5 - - - _target_: kornia.augmentation.RandomCrop - p: 0.5 - size: [64, 64] - padding: 8 - - - _target_: kornia.augmentation.RandomPerspective - p: 0.5 - distortion_scale: 0.15 - - - _target_: kornia.augmentation.RandomRotation - p: 0.5 - degrees: 9 - -# - _target_: kornia.augmentation.RandomResizedCrop -# p: 0.5 -# size: [64, 64] -# scale: [0.95, 1.05] -# ratio: [0.95, 1.05] - -# THIS REPLACES MOST OF THE ABOVE BUT IT IS BUGGY ON BATCH -# - _target_: kornia.augmentation.RandomAffine -# p: 0.5 -# degrees: 10 -# translate: [0.14, 0.14] -# scale: [0.95, 1.05] -# shear: 5 diff --git a/experiment/config/augment/example.yaml b/experiment/config/augment/example.yaml new file mode 100644 index 00000000..8c4ce29e --- /dev/null +++ b/experiment/config/augment/example.yaml @@ -0,0 +1,8 @@ +name: basic + +augment_cls: + _target_: torchvision.transforms.ColorJitter + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index bb62ad0c..8267d491 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -1,4 +1,5 @@ defaults: + - _self_ # defaults lists override entries from this file! 
# data - sampling: default__bb - dataset: xyobject @@ -18,8 +19,8 @@ defaults: - run_location: local - run_launcher: local - run_action: train - # entries in this file override entries from default lists - - _self_ + # experiment + - run_plugins: default settings: job: @@ -27,21 +28,20 @@ settings: project: 'DELETE' name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' seed: NULL - framework: beta: 0.0316 recon_loss: mse - loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling` - + loss_reduction: mean # beta scaling framework_opt: latent_distribution: normal # only used by VAEs - model: z_size: 25 weight_init: 'xavier_normal' # xavier_normal, default - dataset: batch_size: 256 - optimizer: lr: 1e-3 +# TODO: https://pytorch-lightning.readthedocs.io/en/stable/common/weights_loading.html +# checkpoint: +# load_checkpoint: NULL # NULL or string +# save_checkpoint: FALSE # boolean, save at end of run -- more advanced checkpointing can be done with a callback! diff --git a/experiment/config/config_test.yaml b/experiment/config/config_test.yaml index 63a66c8a..a24df0c6 100644 --- a/experiment/config/config_test.yaml +++ b/experiment/config/config_test.yaml @@ -1,11 +1,12 @@ defaults: + - _self_ # defaults lists override entries from this file! 
# data - sampling: default__bb - dataset: xyobject - - augment: none + - augment: example # system - framework: betavae - - model: vae_conv64 + - model: linear # training - optimizer: adam - schedule: beta_cyclic @@ -18,8 +19,8 @@ defaults: - run_location: local_cpu - run_launcher: local - run_action: train - # entries in this file override entries from default lists - - _self_ + # experiment + - run_plugins: default settings: job: @@ -27,21 +28,16 @@ settings: project: 'invalid' name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' seed: NULL - framework: beta: 0.0316 recon_loss: mse - loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling` - + loss_reduction: mean # beta scaling framework_opt: latent_distribution: normal # only used by VAEs - model: z_size: 25 weight_init: 'xavier_normal' # xavier_normal, default - dataset: batch_size: 5 - optimizer: lr: 1e-3 diff --git a/experiment/config/dataset/cars3d.yaml b/experiment/config/dataset/cars3d.yaml index 9cb9055e..99e04de6 100644 --- a/experiment/config/dataset/cars3d.yaml +++ b/experiment/config/dataset/cars3d.yaml @@ -4,13 +4,12 @@ defaults: name: cars3d data: - _target_: disent.dataset.data.Cars3dData + _target_: disent.dataset.data.Cars3d64Data data_root: ${dsettings.storage.data_root} prepare: ${dsettings.dataset.prepare} transform: _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 mean: ${dataset.meta.vis_mean} std: ${dataset.meta.vis_std} diff --git a/experiment/config/dataset/monte_rollouts.yaml b/experiment/config/dataset/monte_rollouts.yaml deleted file mode 100644 index 682eeb54..00000000 --- a/experiment/config/dataset/monte_rollouts.yaml +++ /dev/null @@ -1,21 +0,0 @@ -defaults: - - _data_type_: episodes - -name: monte_rollouts - -data: - _target_: disent.dataset.data.EpisodesDownloadZippedPickledData - required_file: 
${dsettings.storage.data_root}/episodes/monte.pkl - download_url: 'https://raw.githubusercontent.com/nmichlo/uploads/main/monte_key.tar.xz' - prepare: ${dsettings.dataset.prepare} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: [64, 64] # slightly squashed? - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] # [3, 210, 160] - vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" - vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" diff --git a/experiment/config/dataset/smallnorb.yaml b/experiment/config/dataset/smallnorb.yaml index 9dfbb8ec..6ffa56fc 100644 --- a/experiment/config/dataset/smallnorb.yaml +++ b/experiment/config/dataset/smallnorb.yaml @@ -4,14 +4,13 @@ defaults: name: smallnorb data: - _target_: disent.dataset.data.SmallNorbData + _target_: disent.dataset.data.SmallNorb64Data data_root: ${dsettings.storage.data_root} prepare: ${dsettings.dataset.prepare} is_test: False transform: _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 mean: ${dataset.meta.vis_mean} std: ${dataset.meta.vis_std} diff --git a/experiment/config/framework/_input_mode_/triple.yaml b/experiment/config/framework/_input_mode_/triplet.yaml similarity index 79% rename from experiment/config/framework/_input_mode_/triple.yaml rename to experiment/config/framework/_input_mode_/triplet.yaml index 40f44980..13d895bb 100644 --- a/experiment/config/framework/_input_mode_/triple.yaml +++ b/experiment/config/framework/_input_mode_/triplet.yaml @@ -1,3 +1,3 @@ # controlled by the framework's defaults list -name: triple +name: triplet num: 3 diff --git a/experiment/config/framework/adagvae_minimal_os.yaml b/experiment/config/framework/adagvae_minimal_os.yaml index e7bcbf13..e5631398 100644 --- a/experiment/config/framework/adagvae_minimal_os.yaml +++ b/experiment/config/framework/adagvae_minimal_os.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: 
${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} diff --git a/experiment/config/framework/adavae.yaml b/experiment/config/framework/adavae.yaml index d234dcbf..b0c77133 100644 --- a/experiment/config/framework/adavae.yaml +++ b/experiment/config/framework/adavae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} # adavae diff --git a/experiment/config/framework/adavae_os.yaml b/experiment/config/framework/adavae_os.yaml index d8054386..f9225746 100644 --- a/experiment/config/framework/adavae_os.yaml +++ b/experiment/config/framework/adavae_os.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} # adavae diff --git a/experiment/config/framework/ae.yaml b/experiment/config/framework/ae.yaml index cd410b6b..c1fc7f27 100644 --- a/experiment/config/framework/ae.yaml +++ b/experiment/config/framework/ae.yaml @@ -9,7 +9,7 @@ cfg: recon_loss: ${settings.framework.recon_loss} loss_reduction: ${settings.framework.loss_reduction} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE diff --git a/experiment/config/framework/betatcvae.yaml b/experiment/config/framework/betatcvae.yaml index 32005237..ed4ccf73 100644 --- 
a/experiment/config/framework/betatcvae.yaml +++ b/experiment/config/framework/betatcvae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-TcVae beta: ${settings.framework.beta} diff --git a/experiment/config/framework/betavae.yaml b/experiment/config/framework/betavae.yaml index eb0f1540..031992cb 100644 --- a/experiment/config/framework/betavae.yaml +++ b/experiment/config/framework/betavae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} diff --git a/experiment/config/framework/dfcvae.yaml b/experiment/config/framework/dfcvae.yaml index 9a242f1d..6426366c 100644 --- a/experiment/config/framework/dfcvae.yaml +++ b/experiment/config/framework/dfcvae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} # dfcvae diff --git a/experiment/config/framework/dipvae.yaml b/experiment/config/framework/dipvae.yaml index 4efebcf4..8bd42c05 100644 --- a/experiment/config/framework/dipvae.yaml +++ b/experiment/config/framework/dipvae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: 
FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} # DIP-VAE diff --git a/experiment/config/framework/infovae.yaml b/experiment/config/framework/infovae.yaml index e9b8234b..3867b2d4 100644 --- a/experiment/config/framework/infovae.yaml +++ b/experiment/config/framework/infovae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Info-VAE # info vae is not based off beta vae, but with # the correct parameter choice this can equal the beta vae diff --git a/experiment/config/framework/tae.yaml b/experiment/config/framework/tae.yaml index 39fffd61..c5279846 100644 --- a/experiment/config/framework/tae.yaml +++ b/experiment/config/framework/tae.yaml @@ -9,7 +9,7 @@ cfg: recon_loss: ${settings.framework.recon_loss} loss_reduction: ${settings.framework.loss_reduction} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE # tvae: triplet stuffs diff --git a/experiment/config/framework/tvae.yaml b/experiment/config/framework/tvae.yaml index 601c52d8..baf9d5a4 100644 --- a/experiment/config/framework/tvae.yaml +++ b/experiment/config/framework/tvae.yaml @@ -11,11 +11,10 @@ cfg: # base vae latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL # Beta-VAE beta: ${settings.framework.beta} # tvae: triplet stuffs diff --git a/experiment/config/framework/vae.yaml b/experiment/config/framework/vae.yaml index 31ec83c4..4c4d022f 100644 --- a/experiment/config/framework/vae.yaml +++ b/experiment/config/framework/vae.yaml @@ -11,11 +11,10 @@ cfg: # base vae 
latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components - disable_decoder: FALSE + detach_decoder: FALSE disable_reg_loss: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE - disable_posterior_scale: NULL meta: model_z_multiplier: 2 diff --git a/experiment/config/metrics/standard.yaml b/experiment/config/metrics/standard.yaml deleted file mode 100644 index 49ba02de..00000000 --- a/experiment/config/metrics/standard.yaml +++ /dev/null @@ -1,15 +0,0 @@ -metric_list: - - mig: {} - - sap: {} - - dci: - every_n_steps: 7200 - on_final: TRUE - - factor_vae: - every_n_steps: 7200 - on_final: TRUE - -# these are the default settings, these can be placed in the list above -default_on_final: TRUE -default_on_train: TRUE -default_every_n_steps: 2400 -default_begin_first_step: FALSE diff --git a/experiment/config/model/linear.yaml b/experiment/config/model/linear.yaml index 30e1ed1e..d63408c4 100644 --- a/experiment/config/model/linear.yaml +++ b/experiment/config/model/linear.yaml @@ -1,12 +1,18 @@ name: linear -encoder_cls: - _target_: disent.model.ae.EncoderLinear - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} - z_multiplier: ${framework.meta.model_z_multiplier} - -decoder_cls: - _target_: disent.model.ae.DecoderLinear - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} +model_cls: + # weight initialisation + _target_: disent.nn.weights.init_model_weights + mode: ${settings.model.weight_init} + model: + # auto-encoder + _target_: disent.model.AutoEncoder + encoder: + _target_: disent.model.ae.EncoderLinear + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + decoder: + _target_: disent.model.ae.DecoderLinear + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/model/norm_conv64.yaml b/experiment/config/model/norm_conv64.yaml index 3068055b..29db4b6d 100644 --- 
a/experiment/config/model/norm_conv64.yaml +++ b/experiment/config/model/norm_conv64.yaml @@ -1,23 +1,29 @@ name: norm_conv64 -encoder_cls: - _target_: disent.model.ae.EncoderConv64Norm - x_shape: ${data.meta.x_shape} - z_size: ${settings.model.z_size} - z_multiplier: ${framework.meta.model_z_multiplier} - activation: ${model.activation} - norm: ${model.norm} - norm_pre_act: ${model.norm_pre_act} +model_cls: + # weight initialisation + _target_: disent.nn.weights.init_model_weights + mode: ${settings.model.weight_init} + model: + # auto-encoder + _target_: disent.model.AutoEncoder + encoder: + _target_: disent.model.ae.EncoderConv64Norm + x_shape: ${data.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + activation: ${model.meta.activation} + norm: ${model.meta.norm} + norm_pre_act: ${model.meta.norm_pre_act} + decoder: + _target_: disent.model.ae.DecoderConv64Norm + x_shape: ${data.meta.x_shape} + z_size: ${settings.model.z_size} + activation: ${model.meta.activation} + norm: ${model.meta.norm} + norm_pre_act: ${model.meta.norm_pre_act} -decoder_cls: - _target_: disent.model.ae.DecoderConv64Norm - x_shape: ${data.meta.x_shape} - z_size: ${settings.model.z_size} - activation: ${model.activation} - norm: ${model.norm} - norm_pre_act: ${model.norm_pre_act} - -# vars -activation: swish # leaky_relu, relu -norm: layer # batch, instance, layer, layer_chn, none -norm_pre_act: TRUE +meta: + activation: swish # leaky_relu, relu + norm: layer # batch, instance, layer, layer_chn, none + norm_pre_act: TRUE diff --git a/experiment/config/model/vae_conv64.yaml b/experiment/config/model/vae_conv64.yaml index a05d00c0..b5e674e9 100644 --- a/experiment/config/model/vae_conv64.yaml +++ b/experiment/config/model/vae_conv64.yaml @@ -1,12 +1,18 @@ name: vae_conv64 -encoder_cls: - _target_: disent.model.ae.EncoderConv64 - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} - z_multiplier: 
${framework.meta.model_z_multiplier} - -decoder_cls: - _target_: disent.model.ae.DecoderConv64 - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} +model_cls: + # weight initialisation + _target_: disent.nn.weights.init_model_weights + mode: ${settings.model.weight_init} + model: + # auto-encoder + _target_: disent.model.AutoEncoder + encoder: + _target_: disent.model.ae.EncoderConv64 + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + decoder: + _target_: disent.model.ae.DecoderConv64 + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/model/vae_fc.yaml b/experiment/config/model/vae_fc.yaml index 6a68f40d..69c23a63 100644 --- a/experiment/config/model/vae_fc.yaml +++ b/experiment/config/model/vae_fc.yaml @@ -1,12 +1,18 @@ name: vae_fc -encoder_cls: - _target_: disent.model.ae.EncoderFC - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} - z_multiplier: ${framework.meta.model_z_multiplier} - -decoder_cls: - _target_: disent.model.ae.DecoderFC - x_shape: ${dataset.meta.x_shape} - z_size: ${settings.model.z_size} +model_cls: + # weight initialisation + _target_: disent.nn.weights.init_model_weights + mode: ${settings.model.weight_init} + model: + # auto-encoder + _target_: disent.model.AutoEncoder + encoder: + _target_: disent.model.ae.EncoderFC + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + decoder: + _target_: disent.model.ae.DecoderFC + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/run_callbacks/all.yaml b/experiment/config/run_callbacks/all.yaml index fd4c88a6..a8f24320 100644 --- a/experiment/config/run_callbacks/all.yaml +++ b/experiment/config/run_callbacks/all.yaml @@ -1,42 +1,22 @@ # @package _global_ callbacks: + latent_cycle: + _target_: 
disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback seed: 7777 every_n_steps: 3600 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback seed: 7777 every_n_steps: 3600 traversal_repeats: 100 begin_first_step: TRUE - -# correlation: -# repeats_per_factor: 10 -# every_n_steps: 7200 - -# latent_cycle: -# _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback -# seed: 7777 -# every_n_steps: 7200 -# begin_first_step: FALSE -# mode: 'minmax_interval_cycle' # minmax_interval_cycle, fitted_gaussian_cycle -# recon_min: ${dataset.vis_min} -# recon_max: ${dataset.vis_max} -# recon_mean: ${dataset.vis_mean} -# recon_std: ${dataset.vis_std} -# -# gt_dists: -# _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback -# seed: 7777 -# every_n_steps: 7200 -# begin_first_step: TRUE -# traversal_repeats: 100 -# plt_block_size: 1.25 -# plt_show: False -# plt_transpose: False -# log_wandb: TRUE -# batch_size: ${cfg.dataset.batch_size} -# include_factor_dists: TRUE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/test.yaml b/experiment/config/run_callbacks/test.yaml index 255cda8a..3e680425 100644 --- a/experiment/config/run_callbacks/test.yaml +++ b/experiment/config/run_callbacks/test.yaml @@ -1,18 +1,22 @@ # @package _global_ callbacks: + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback seed: 7777 every_n_steps: 3 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' begin_first_step: FALSE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + 
log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback seed: 7777 every_n_steps: 4 traversal_repeats: 3 begin_first_step: FALSE - -# correlation: -# repeats_per_factor: 3 -# every_n_steps: 5 + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis.yaml b/experiment/config/run_callbacks/vis.yaml index 2087779c..a8f24320 100644 --- a/experiment/config/run_callbacks/vis.yaml +++ b/experiment/config/run_callbacks/vis.yaml @@ -1,14 +1,22 @@ # @package _global_ callbacks: + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback seed: 7777 every_n_steps: 3600 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback seed: 7777 every_n_steps: 3600 traversal_repeats: 100 begin_first_step: TRUE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis_debug.yaml b/experiment/config/run_callbacks/vis_debug.yaml new file mode 100644 index 00000000..3791ae53 --- /dev/null +++ b/experiment/config/run_callbacks/vis_debug.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +callbacks: + + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback + seed: 7777 + every_n_steps: 600 + begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} + + gt_dists: + _target_: 
disent.util.lightning.callbacks.VaeGtDistsLoggingCallback + seed: 7777 + every_n_steps: 600 + traversal_repeats: 50 + begin_first_step: FALSE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis_fast.yaml b/experiment/config/run_callbacks/vis_fast.yaml index 3df24a15..4c1b3802 100644 --- a/experiment/config/run_callbacks/vis_fast.yaml +++ b/experiment/config/run_callbacks/vis_fast.yaml @@ -1,14 +1,22 @@ # @package _global_ callbacks: + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback seed: 7777 every_n_steps: 1800 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback seed: 7777 every_n_steps: 1800 traversal_repeats: 100 begin_first_step: TRUE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis_quick.yaml b/experiment/config/run_callbacks/vis_quick.yaml new file mode 100644 index 00000000..d37d8805 --- /dev/null +++ b/experiment/config/run_callbacks/vis_quick.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +callbacks: + + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback + seed: 7777 + every_n_steps: 600 + begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} + + gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback + seed: 7777 + every_n_steps: 1800 + traversal_repeats: 50 + begin_first_step: FALSE + log_wandb: 
${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis_skip_first.yaml b/experiment/config/run_callbacks/vis_skip_first.yaml new file mode 100644 index 00000000..883ffb0d --- /dev/null +++ b/experiment/config/run_callbacks/vis_skip_first.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +callbacks: + + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback + seed: 7777 + every_n_steps: 3600 + begin_first_step: FALSE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} + + gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback + seed: 7777 + every_n_steps: 3600 + traversal_repeats: 100 + begin_first_step: FALSE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_callbacks/vis_slow.yaml b/experiment/config/run_callbacks/vis_slow.yaml index 83f0516b..6dca0c0f 100644 --- a/experiment/config/run_callbacks/vis_slow.yaml +++ b/experiment/config/run_callbacks/vis_slow.yaml @@ -1,14 +1,22 @@ # @package _global_ callbacks: + latent_cycle: + _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback seed: 7777 every_n_steps: 7200 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' begin_first_step: TRUE + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + log_wandb: ${logging.wandb.enabled} + recon_mean: ${dataset.meta.vis_mean} + recon_std: ${dataset.meta.vis_std} gt_dists: + _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback seed: 7777 every_n_steps: 7200 traversal_repeats: 100 begin_first_step: TRUE + log_wandb: ${logging.wandb.enabled} + batch_size: ${settings.dataset.batch_size} diff --git a/experiment/config/run_length/longmed.yaml 
b/experiment/config/run_length/longmed.yaml new file mode 100644 index 00000000..7ffd07a5 --- /dev/null +++ b/experiment/config/run_length/longmed.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +trainer: + max_epochs: 86400 + max_steps: 86400 diff --git a/experiment/config/run_location/local.yaml b/experiment/config/run_location/local.yaml index cc788508..8c2f249e 100644 --- a/experiment/config/run_location/local.yaml +++ b/experiment/config/run_location/local.yaml @@ -2,22 +2,21 @@ dsettings: trainer: - cuda: NULL # `NULL` tries to use CUDA if it is available. `TRUE` forces cuda to be used! + cuda: NULL # `NULL` tries to use CUDA if it is available, otherwise defaulting to the CPU storage: logs_dir: 'logs' data_root: '/tmp/${oc.env:USER}/datasets' dataset: - gpu_augment: FALSE prepare: TRUE try_in_memory: TRUE -trainer: +datamodule: + gpu_augment: FALSE prepare_data_per_node: TRUE - -dataloader: - num_workers: 8 - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} + dataloader: + num_workers: 8 + pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! + batch_size: ${settings.dataset.batch_size} hydra: job: diff --git a/experiment/config/run_location/local_cpu.yaml b/experiment/config/run_location/local_cpu.yaml index d1bcd61a..07ccfb84 100644 --- a/experiment/config/run_location/local_cpu.yaml +++ b/experiment/config/run_location/local_cpu.yaml @@ -2,22 +2,21 @@ dsettings: trainer: - cuda: FALSE + cuda: FALSE # The job will only use the CPU storage: logs_dir: 'logs' data_root: '/tmp/${oc.env:USER}/datasets' dataset: - gpu_augment: FALSE prepare: TRUE try_in_memory: TRUE -trainer: +datamodule: + gpu_augment: FALSE prepare_data_per_node: TRUE - -dataloader: - num_workers: 8 - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} + dataloader: + num_workers: 8 + pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 
+ batch_size: ${settings.dataset.batch_size} hydra: job: diff --git a/experiment/config/run_location/cluster.yaml b/experiment/config/run_location/local_gpu.yaml similarity index 59% rename from experiment/config/run_location/cluster.yaml rename to experiment/config/run_location/local_gpu.yaml index 9194493f..c166c948 100644 --- a/experiment/config/run_location/cluster.yaml +++ b/experiment/config/run_location/local_gpu.yaml @@ -2,26 +2,21 @@ dsettings: trainer: - cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly + cuda: TRUE # `TRUE` forces cuda to be used. The job fails if cuda is not available! storage: logs_dir: 'logs' data_root: '/tmp/${oc.env:USER}/datasets' dataset: - gpu_augment: FALSE prepare: TRUE try_in_memory: TRUE - launcher: - partition: batch - array_parallelism: 16 - exclude: "mscluster93,mscluster94,mscluster97,mscluster99" -trainer: +datamodule: + gpu_augment: FALSE prepare_data_per_node: TRUE - -dataloader: - num_workers: 8 - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} + dataloader: + num_workers: 8 + pin_memory: ${dsettings.trainer.cuda} # uses more memory, but faster! 
+ batch_size: ${settings.dataset.batch_size} hydra: job: diff --git a/experiment/config/run_location/stampede_shr.yaml b/experiment/config/run_location/stampede_shr.yaml deleted file mode 100644 index 05ba8626..00000000 --- a/experiment/config/run_location/stampede_shr.yaml +++ /dev/null @@ -1,33 +0,0 @@ -# @package _global_ - -dsettings: - trainer: - cuda: NULL # auto-detect cuda, some nodes are not configured correctly - storage: - logs_dir: 'logs' - data_root: '${oc.env:HOME}/downloads/datasets' # WE NEED TO BE VERY CAREFUL ABOUT USING A SHARED DRIVE - dataset: - gpu_augment: FALSE - prepare: FALSE # WE MUST PREPARE DATA MANUALLY BEFOREHAND - try_in_memory: TRUE - launcher: - partition: stampede - array_parallelism: 16 - exclude: "mscluster93,mscluster94,mscluster97,mscluster99" - -trainer: - prepare_data_per_node: TRUE - -dataloader: - num_workers: 16 - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} - -hydra: - job: - name: 'disent' - run: - dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - sweep: - dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_location/stampede_tmp.yaml b/experiment/config/run_location/stampede_tmp.yaml deleted file mode 100644 index d596b335..00000000 --- a/experiment/config/run_location/stampede_tmp.yaml +++ /dev/null @@ -1,33 +0,0 @@ -# @package _global_ - -dsettings: - trainer: - cuda: NULL # auto-detect cuda, some nodes are not configured correctly - storage: - logs_dir: 'logs' - data_root: '/tmp/${oc.env:USER}/datasets' - dataset: - gpu_augment: FALSE - prepare: TRUE - try_in_memory: TRUE - launcher: - partition: stampede - array_parallelism: 16 - exclude: "mscluster93,mscluster94,mscluster97,mscluster99" - -trainer: - prepare_data_per_node: TRUE - -dataloader: - num_workers: 16 - pin_memory: 
${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} - -hydra: - job: - name: 'disent' - run: - dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - sweep: - dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_logging/none.yaml b/experiment/config/run_logging/none.yaml index 2a99d3c4..a9c656ff 100644 --- a/experiment/config/run_logging/none.yaml +++ b/experiment/config/run_logging/none.yaml @@ -6,13 +6,14 @@ defaults: trainer: log_every_n_steps: 50 - flush_logs_every_n_steps: 100 - progress_bar_refresh_rate: 0 # disable the builtin progress bar + enable_progress_bar: FALSE # disable the builtin progress bar callbacks: progress: + _target_: disent.util.lightning.callbacks.LoggerProgressCallback interval: 5 logging: wandb: enabled: FALSE + logger: NULL diff --git a/experiment/config/run_logging/wandb.yaml b/experiment/config/run_logging/wandb.yaml index a72cb887..53c7f12c 100644 --- a/experiment/config/run_logging/wandb.yaml +++ b/experiment/config/run_logging/wandb.yaml @@ -6,19 +6,26 @@ defaults: trainer: log_every_n_steps: 100 - flush_logs_every_n_steps: 200 - progress_bar_refresh_rate: 0 # disable the builtin progress bar + enable_progress_bar: FALSE # disable the builtin progress bar callbacks: progress: + _target_: disent.util.lightning.callbacks.LoggerProgressCallback interval: 15 logging: wandb: enabled: TRUE + logger: + _target_: pytorch_lightning.loggers.WandbLogger offline: FALSE - entity: '${settings.job.user}' - project: '${settings.job.project}' - name: '${settings.job.name}' - group: null + entity: ${settings.job.user} + project: ${settings.job.project} + name: ${settings.job.name} + group: NULL tags: [] + save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd + # https://docs.wandb.ai/guides/track/launch#init-start-error + 
settings: + _target_: wandb.Settings + start_method: "fork" # fork: linux/macos, thread: google colab diff --git a/experiment/config/run_logging/wandb_fast.yaml b/experiment/config/run_logging/wandb_fast.yaml index 75c305ee..f0f978b9 100644 --- a/experiment/config/run_logging/wandb_fast.yaml +++ b/experiment/config/run_logging/wandb_fast.yaml @@ -6,19 +6,26 @@ defaults: trainer: log_every_n_steps: 50 - flush_logs_every_n_steps: 100 - progress_bar_refresh_rate: 0 # disable the builtin progress bar + enable_progress_bar: FALSE # disable the builtin progress bar callbacks: progress: + _target_: disent.util.lightning.callbacks.LoggerProgressCallback interval: 5 logging: wandb: enabled: TRUE + logger: + _target_: pytorch_lightning.loggers.WandbLogger offline: FALSE - entity: '${settings.job.user}' - project: '${settings.job.project}' - name: '${settings.job.name}' - group: null + entity: ${settings.job.user} + project: ${settings.job.project} + name: ${settings.job.name} + group: NULL tags: [] + save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd + # https://docs.wandb.ai/guides/track/launch#init-start-error + settings: + _target_: wandb.Settings + start_method: "fork" # fork: linux/macos, thread: google colab diff --git a/experiment/config/run_logging/wandb_fast_offline.yaml b/experiment/config/run_logging/wandb_fast_offline.yaml index 372e480a..0095a2c7 100644 --- a/experiment/config/run_logging/wandb_fast_offline.yaml +++ b/experiment/config/run_logging/wandb_fast_offline.yaml @@ -6,19 +6,26 @@ defaults: trainer: log_every_n_steps: 50 - flush_logs_every_n_steps: 100 - progress_bar_refresh_rate: 0 # disable the builtin progress bar + enable_progress_bar: FALSE # disable the builtin progress bar callbacks: progress: + _target_: disent.util.lightning.callbacks.LoggerProgressCallback interval: 5 logging: wandb: enabled: TRUE + logger: + _target_: pytorch_lightning.loggers.WandbLogger offline: TRUE - entity: '${settings.job.user}' - 
project: '${settings.job.project}' - name: '${settings.job.name}' - group: null + entity: ${settings.job.user} + project: ${settings.job.project} + name: ${settings.job.name} + group: NULL tags: [] + save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd + # https://docs.wandb.ai/guides/track/launch#init-start-error + settings: + _target_: wandb.Settings + start_method: "fork" # fork: linux/macos, thread: google colab diff --git a/experiment/config/run_logging/wandb_slow.yaml b/experiment/config/run_logging/wandb_slow.yaml index f7f4a49c..5718572f 100644 --- a/experiment/config/run_logging/wandb_slow.yaml +++ b/experiment/config/run_logging/wandb_slow.yaml @@ -6,19 +6,26 @@ defaults: trainer: log_every_n_steps: 200 - flush_logs_every_n_steps: 400 - progress_bar_refresh_rate: 0 # disable the builtin progress bar + enable_progress_bar: FALSE # disable the builtin progress bar callbacks: progress: + _target_: disent.util.lightning.callbacks.LoggerProgressCallback interval: 30 logging: wandb: enabled: TRUE + logger: + _target_: pytorch_lightning.loggers.WandbLogger offline: FALSE - entity: '${settings.job.user}' - project: '${settings.job.project}' - name: '${settings.job.name}' - group: null + entity: ${settings.job.user} + project: ${settings.job.project} + name: ${settings.job.name} + group: NULL tags: [] + save_dir: ${abspath:${dsettings.storage.logs_dir}} # relative to hydra's original cwd + # https://docs.wandb.ai/guides/track/launch#init-start-error + settings: + _target_: wandb.Settings + start_method: "fork" # fork: linux/macos, thread: google colab diff --git a/experiment/config/run_plugins/default.yaml b/experiment/config/run_plugins/default.yaml new file mode 100644 index 00000000..c608f5f3 --- /dev/null +++ b/experiment/config/run_plugins/default.yaml @@ -0,0 +1,6 @@ +# @package _global_ + +# call the listed functions here before the experiment is started +# - this can be used to register functions or metrics to the disent 
registry for example! +experiment: + plugins: [] diff --git a/experiment/config/sampling/gt_dist__manhat_scaled.yaml b/experiment/config/sampling/gt_dist__manhat_scaled.yaml index 5fb96993..5de529ab 100644 --- a/experiment/config/sampling/gt_dist__manhat_scaled.yaml +++ b/experiment/config/sampling/gt_dist__manhat_scaled.yaml @@ -4,5 +4,5 @@ defaults: name: gt_dist__manhat_scaled -triplet_sample_mode: "manhattan_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_sample_mode: 'manhattan_scaled' # random, factors, manhattan, manhattan_scaled, combined, combined_scaled triplet_swap_chance: 0 diff --git a/experiment/config/schedule/adavae_down_all.yaml b/experiment/config/schedule/adavae_down_all.yaml deleted file mode 100644 index 8d4ac7f5..00000000 --- a/experiment/config/schedule/adavae_down_all.yaml +++ /dev/null @@ -1,27 +0,0 @@ -name: averaging_decrease__all - -schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # ada triplet - r_end: 0.0 # triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # loss active - r_end: 0.0 # loss inactive - adat_triplet_share_scale: # reversed compared to others - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.5 # ada weighted triplet - r_end: 1.0 # normal triplet - ada_thresh_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # all averaged, should this not be 0.5 the recommended value - r_end: 0.0 # none averaged diff --git a/experiment/config/schedule/adavae_down_ratio.yaml b/experiment/config/schedule/adavae_down_ratio.yaml deleted file mode 100644 index e3816fe1..00000000 --- a/experiment/config/schedule/adavae_down_ratio.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: averaging_decrease__ratio - 
-schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # ada triplet - r_end: 0.0 # triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # loss active - r_end: 0.0 # loss inactive - adat_triplet_share_scale: # reversed compared to others - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.5 # ada weighted triplet - r_end: 1.0 # normal triplet diff --git a/experiment/config/schedule/adavae_down_thresh.yaml b/experiment/config/schedule/adavae_down_thresh.yaml deleted file mode 100644 index 0ca660ac..00000000 --- a/experiment/config/schedule/adavae_down_thresh.yaml +++ /dev/null @@ -1,9 +0,0 @@ -name: averaging_decrease__thresh - -schedule_items: - ada_thresh_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # all averaged, should this not be 0.5 the recommended value - r_end: 0.0 # none averaged diff --git a/experiment/config/schedule/adavae_up_all.yaml b/experiment/config/schedule/adavae_up_all.yaml deleted file mode 100644 index 58cdd98d..00000000 --- a/experiment/config/schedule/adavae_up_all.yaml +++ /dev/null @@ -1,27 +0,0 @@ -name: averaging_increase__all - -schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # triplet - r_end: 1.0 # ada triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # loss inactive - r_end: 1.0 # loss active - adat_triplet_share_scale: # reversed compared to others - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # normal triplet - r_end: 0.5 # ada weighted triplet - ada_thresh_ratio: - _target_: 
disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # none averaged - r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/config/schedule/adavae_up_all_full.yaml b/experiment/config/schedule/adavae_up_all_full.yaml deleted file mode 100644 index 471a8da4..00000000 --- a/experiment/config/schedule/adavae_up_all_full.yaml +++ /dev/null @@ -1,27 +0,0 @@ -name: averaging_increase__all_full - -schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # triplet - r_end: 1.0 # ada triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # loss inactive - r_end: 1.0 # loss active - adat_triplet_share_scale: # reversed compared to others - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # normal triplet - r_end: 0.0 # ada weighted triplet - ada_thresh_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # none averaged - r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/config/schedule/adavae_up_ratio.yaml b/experiment/config/schedule/adavae_up_ratio.yaml deleted file mode 100644 index 79e3e6c3..00000000 --- a/experiment/config/schedule/adavae_up_ratio.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: averaging_increase__ratio - -schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # triplet - r_end: 1.0 # ada triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # loss inactive - r_end: 1.0 # loss active - adat_triplet_share_scale: # reversed compared to others - _target_: 
disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # normal triplet - r_end: 0.5 # ada weighted triplet diff --git a/experiment/config/schedule/adavae_up_ratio_full.yaml b/experiment/config/schedule/adavae_up_ratio_full.yaml deleted file mode 100644 index 5b7fabe3..00000000 --- a/experiment/config/schedule/adavae_up_ratio_full.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: averaging_increase__ratio_full - -schedule_items: - adat_triplet_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # triplet - r_end: 1.0 # ada triplet - adat_triplet_soft_scale: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # loss inactive - r_end: 1.0 # loss active - adat_triplet_share_scale: # reversed compared to others - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 1.0 # normal triplet - r_end: 0.0 # ada weighted triplet diff --git a/experiment/config/schedule/adavae_up_thresh.yaml b/experiment/config/schedule/adavae_up_thresh.yaml deleted file mode 100644 index 7c012a32..00000000 --- a/experiment/config/schedule/adavae_up_thresh.yaml +++ /dev/null @@ -1,9 +0,0 @@ -name: averaging_increase__thresh - -schedule_items: - ada_thresh_ratio: - _target_: disent.schedule.LinearSchedule - start_step: 0 - end_step: ${trainer.max_steps} - r_start: 0.0 # none averaged - r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/run.py b/experiment/run.py index bfe57545..c5827093 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -25,6 +25,8 @@ import logging import os from datetime import datetime +from typing import Callable +from typing import Optional import hydra import pytorch_lightning as pl @@ -34,22 +36,19 @@ from omegaconf import DictConfig from omegaconf import ListConfig from omegaconf import OmegaConf -from 
pytorch_lightning.loggers import WandbLogger +from pytorch_lightning import Callback +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.loggers import LightningLoggerBase -from disent import metrics -from disent.frameworks import DisentConfigurable +import disent.registry as R from disent.frameworks import DisentFramework -from disent.model import AutoEncoder -from disent.nn.weights import init_model_weights +from disent.util.lightning.callbacks import VaeMetricLoggingCallback from disent.util.seeds import seed -from disent.util.strings.fmt import make_box_str from disent.util.strings import colors as c -from disent.util.lightning.callbacks import LoggerProgressCallback -from disent.util.lightning.callbacks import VaeMetricLoggingCallback -from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback -from disent.util.lightning.callbacks import VaeGtDistsLoggingCallback +from disent.util.strings.fmt import make_box_str + from experiment.util.hydra_data import HydraDataModule -from experiment.util.run_utils import log_error_and_exit +from experiment.util.hydra_main import hydra_main from experiment.util.run_utils import safe_unset_debug_logger from experiment.util.run_utils import safe_unset_debug_trainer from experiment.util.run_utils import set_debug_logger @@ -64,6 +63,17 @@ # ========================================================================= # +def hydra_register_disent_plugins(cfg): + # TODO: there should be a plugin mechanism for disent? + if cfg.experiment.plugins: + log.info('Running experiment plugins:') + for plugin in cfg.experiment.plugins: + log.info(f'* registering: {plugin}') + hydra.utils.instantiate(dict(_target_=plugin)) + else: + log.info('No experiment plugins were listed. 
Register these under the `experiment.plugins` in the config, which lists targets of functions.') + + def hydra_get_gpus(cfg) -> int: use_cuda = cfg.dsettings.trainer.cuda # check cuda values @@ -85,7 +95,7 @@ def hydra_get_gpus(cfg) -> int: def hydra_check_data_paths(cfg): - prepare_data_per_node = cfg.trainer.prepare_data_per_node + prepare_data_per_node = cfg.datamodule.prepare_data_per_node data_root = cfg.dsettings.storage.data_root # check relative paths if not os.path.isabs(data_root): @@ -98,78 +108,32 @@ def hydra_check_data_paths(cfg): ) if prepare_data_per_node: log.error( - f'trainer.prepare_data_per_node={repr(prepare_data_per_node)} but dsettings.storage.data_root=' + f'datamodule.prepare_data_per_node={repr(prepare_data_per_node)} but dsettings.storage.data_root=' f'{repr(data_root)} is a relative path which may be an error! Try specifying an' f' absolute path that is guaranteed to be unique from each node, eg. default_settings.storage.data_root=/tmp/dataset' ) raise RuntimeError(f'default_settings.storage.data_root={repr(data_root)} is a relative path!') -def hydra_make_logger(cfg): - # make wandb logger - backend = cfg.logging.wandb - if backend.enabled: - log.info('Initialising Weights & Biases Logger') - return WandbLogger( - offline=backend.offline, - entity=backend.entity, # cometml: workspace - project=backend.project, # cometml: project_name - name=backend.name, # cometml: experiment_name - group=backend.group, # experiment group - tags=backend.tags, # experiment tags - save_dir=hydra.utils.to_absolute_path(cfg.dsettings.storage.logs_dir), # relative to hydra's original cwd - ) - # don't return a logger - return None # LoggerCollection([...]) OR DummyLogger(...) 
- - -def _callback_make_progress(cfg, callback_cfg): - return LoggerProgressCallback( - interval=callback_cfg.interval - ) - - -def _callback_make_latent_cycle(cfg, callback_cfg): - if cfg.logging.wandb.enabled: - # checks - if not (('vis_min' in cfg.dataset and 'vis_max' in cfg.dataset) or ('vis_mean' in cfg.dataset and 'vis_std' in cfg.dataset)): - log.warning('dataset does not have visualisation ranges specified, set `vis_min` & `vis_max` OR `vis_mean` & `vis_std`') - # this currently only supports WANDB logger - return VaeLatentCycleLoggingCallback( - seed = callback_cfg.seed, - every_n_steps = callback_cfg.every_n_steps, - begin_first_step = callback_cfg.begin_first_step, - mode = callback_cfg.mode, - # recon_min = cfg.data.meta.vis_min, - # recon_max = cfg.data.meta.vis_max, - recon_mean = cfg.dataset.meta.vis_mean, - recon_std = cfg.dataset.meta.vis_std, - ) +def hydra_check_data_meta(cfg): + # checks + if (cfg.dataset.meta.vis_mean is None) or (cfg.dataset.meta.vis_std is None): + log.warning(f'Dataset has no normalisation values... 
Are you sure this is correct?') + log.warning(f'* dataset.meta.vis_mean: {cfg.dataset.meta.vis_mean}') + log.warning(f'* dataset.meta.vis_std: {cfg.dataset.meta.vis_std}') else: - log.warning('latent_cycle callback is not being used because wandb is not enabled!') - return None - - -def _callback_make_gt_dists(cfg, callback_cfg): - return VaeGtDistsLoggingCallback( - seed = callback_cfg.seed, - every_n_steps = callback_cfg.every_n_steps, - traversal_repeats = callback_cfg.traversal_repeats, - begin_first_step = callback_cfg.begin_first_step, - plt_block_size = 1.25, - plt_show = False, - plt_transpose = False, - log_wandb = True, - batch_size = cfg.settings.dataset.batch_size, - include_factor_dists = True, - ) + log.info(f'Dataset has normalisation values!') + log.info(f'* dataset.meta.vis_mean: {cfg.dataset.meta.vis_mean}') + log.info(f'* dataset.meta.vis_std: {cfg.dataset.meta.vis_std}') -_CALLBACK_MAKERS = { - 'progress': _callback_make_progress, - 'latent_cycle': _callback_make_latent_cycle, - 'gt_dists': _callback_make_gt_dists, -} +def hydra_make_logger(cfg) -> Optional[LightningLoggerBase]: + logger = hydra.utils.instantiate(cfg.logging.logger) + if logger: + log.info(f'Initialised Logger: {logger}') + else: + log.warning(f'No Logger Utilised!') + return logger def hydra_get_callbacks(cfg) -> list: @@ -177,21 +141,16 @@ def hydra_get_callbacks(cfg) -> list: # add all callbacks for name, item in cfg.callbacks.items(): # custom callback handling vs instantiation - if '_target_' in item: - name = f'{name} ({item._target_})' - callback = hydra.utils.instantiate(item) - else: - callback = _CALLBACK_MAKERS[name](cfg, item) + callback = hydra.utils.instantiate(item) + assert isinstance(callback, Callback), f'instantiated callback is not an instance of {Callback}, got: {callback}' # add to callbacks list - if callback is not None: - log.info(f'made callback: {name}') - callbacks.append(callback) - else: - log.info(f'skipped callback: {name}') + log.info(f'made 
callback: {name} ({item._target_})') + callbacks.append(callback) return callbacks def hydra_get_metric_callbacks(cfg) -> list: + # TODO: simplify this, make better use of the config! callbacks = [] # set default values used later default_every_n_steps = cfg.metrics.default_every_n_steps @@ -211,8 +170,8 @@ def hydra_get_metric_callbacks(cfg) -> list: # check values assert isinstance(metric, (dict, DictConfig)), f'settings for entry in metric list is not a dictionary, got type: {type(settings)} or value: {repr(settings)}' # make metrics - train_metric = [metrics.FAST_METRICS[name]] if settings.get('on_train', default_on_train) else None - final_metric = [metrics.DEFAULT_METRICS[name]] if settings.get('on_final', default_on_final) else None + train_metric = [R.METRICS[name].compute_fast] if settings.get('on_train', default_on_train) else None + final_metric = [R.METRICS[name].compute] if settings.get('on_final', default_on_final) else None # add the metric callback if final_metric or train_metric: callbacks.append(VaeMetricLoggingCallback( @@ -224,54 +183,49 @@ def hydra_get_metric_callbacks(cfg) -> list: return callbacks -def hydra_register_schedules(module: DisentFramework, cfg): - # check the type - schedule_items = cfg.schedule.schedule_items - assert isinstance(schedule_items, (dict, DictConfig)), f'`schedule.schedule_items` must be a dictionary, got type: {type(schedule_items)} with value: {repr(schedule_items)}' - # add items - if schedule_items: - log.info(f'Registering Schedules:') - for target, schedule in schedule_items.items(): - module.register_schedule(target, hydra.utils.instantiate(schedule), logging=True) - +def hydra_create_framework(cfg, gpu_batch_augment: Optional[Callable[[torch.Tensor], torch.Tensor]] = None) -> DisentFramework: + # create framework + assert str.endswith(cfg.framework.cfg['_target_'], '.cfg'), f'`cfg.framework.cfg._target_` does not end with ".cfg", got: {repr(cfg.framework.cfg["_target_"])}' + framework_cls = 
hydra.utils.get_class(cfg.framework.cfg['_target_'][:-len(".cfg")]) + framework: DisentFramework = framework_cls( + model=hydra.utils.instantiate(cfg.model.model_cls), + cfg=hydra.utils.instantiate(cfg.framework.cfg, _convert_='all'), # DisentConfigurable -- convert all OmegaConf objects to python equivalents, eg. DictConfig -> dict + batch_augment=gpu_batch_augment, + ) -def hydra_create_and_update_framework_config(cfg) -> DisentConfigurable.cfg: - # create framework config - this is also kinda hacky - # - we need instantiate_recursive because of optimizer_kwargs, - # otherwise the dictionary is left as an OmegaConf dict - framework_cfg: DisentConfigurable.cfg = hydra.utils.instantiate(cfg.framework.cfg) - # warn if some of the cfg variables were not overridden - missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.cfg.keys()))) + # check if some cfg variables were not overridden + missing_keys = sorted(set(framework.cfg.get_keys()) - (set(cfg.framework.cfg.keys()))) if missing_keys: log.warning(f'{c.RED}Framework {repr(cfg.framework.name)} is missing config keys for:{c.RST}') for k in missing_keys: log.warning(f'{c.RED}{repr(k)}{c.RST}') - # return config - return framework_cfg - - -def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cfg): - # specific handling for experiment, this is HACKY! - # - not supported normally, we need to instantiate to get the class (is there hydra support for this?) 
- framework_cfg.optimizer = hydra.utils.get_class(framework_cfg.optimizer) - framework_cfg.optimizer_kwargs = dict(framework_cfg.optimizer_kwargs) - # get framework path - assert str.endswith(cfg.framework.cfg._target_, '.cfg'), f'`cfg.framework.cfg._target_` does not end with ".cfg", got: {repr(cfg.framework.cfg._target_)}' - framework_cls = hydra.utils.get_class(cfg.framework.cfg._target_[:-len(".cfg")]) - # create model - model = AutoEncoder( - encoder=hydra.utils.instantiate(cfg.model.encoder_cls), - decoder=hydra.utils.instantiate(cfg.model.decoder_cls), - ) - # initialise the model - model = init_model_weights(model, mode=cfg.settings.model.weight_init) - # create framework - return framework_cls( - model=model, - cfg=framework_cfg, - batch_augment=datamodule.batch_augment, # apply augmentations to batch on GPU which can be faster than via the dataloader - ) + # register schedules to the framework + schedule_items = cfg.schedule.schedule_items + assert isinstance(schedule_items, (dict, DictConfig)), f'`schedule.schedule_items` must be a dictionary, got type: {type(schedule_items)} with value: {repr(schedule_items)}' + if schedule_items: + log.info(f'Registering Schedules:') + for target, schedule in schedule_items.items(): + framework.register_schedule(target, hydra.utils.instantiate(schedule), logging=True) + + return framework + + +def hydra_make_datamodule(cfg): + return HydraDataModule( + data = cfg.dataset.data, # from: dataset + transform = cfg.dataset.transform, # from: dataset + augment = cfg.augment.augment_cls, # from: augment + sampler = cfg.sampling._sampler_.sampler_cls, # from: sampling + # from: run_location + using_cuda = cfg.dsettings.trainer.cuda, + dataloader_kwargs = cfg.datamodule.dataloader, + augment_on_gpu = cfg.datamodule.gpu_augment, + prepare_data_per_node = cfg.datamodule.prepare_data_per_node, + # from: framework.meta + return_indices = cfg.framework.meta.get('requires_indices', False), + return_factors = 
cfg.framework.meta.get('requires_factors', False), + ) # ========================================================================= # # ACTIONS # @@ -284,15 +238,18 @@ def action_prepare_data(cfg: DictConfig): log.info(f'Starting run at time: {time_string}') # deterministic seed seed(cfg.settings.job.seed) + # register plugins + hydra_register_disent_plugins(cfg) # print useful info log.info(f"Current working directory : {os.getcwd()}") log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") # check data preparation hydra_check_data_paths(cfg) + hydra_check_data_meta(cfg) # print the config log.info(f'Dataset Config Is:\n{make_box_str(OmegaConf.to_yaml({"dataset": cfg.dataset}))}') # prepare data - datamodule = HydraDataModule(cfg) + datamodule = hydra_make_datamodule(cfg) datamodule.prepare_data() @@ -303,8 +260,9 @@ def action_train(cfg: DictConfig): log.info(f'Starting run at time: {time_string}') # -~-~-~-~-~-~-~-~-~-~-~-~- # - # cleanup from old runs: + # -~-~-~-~-~-~-~-~-~-~-~-~- # + try: safe_unset_debug_trainer() safe_unset_debug_logger() @@ -313,75 +271,76 @@ def action_train(cfg: DictConfig): pass # -~-~-~-~-~-~-~-~-~-~-~-~- # - - # deterministic seed - seed(cfg.settings.job.seed) - - # -~-~-~-~-~-~-~-~-~-~-~-~- # - # INITIALISE & SETDEFAULT IN CONFIG + # SETUP # -~-~-~-~-~-~-~-~-~-~-~-~- # # create trainer loggers & callbacks & initialise error messages logger = set_debug_logger(hydra_make_logger(cfg)) + # deterministic seed + seed(cfg.settings.job.seed) + # register plugins + hydra_register_disent_plugins(cfg) # print useful info log.info(f"Current working directory : {os.getcwd()}") log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") - - # check CUDA setting + # checks gpus = hydra_get_gpus(cfg) - - # check data preparation hydra_check_data_paths(cfg) + hydra_check_data_meta(cfg) - # TRAINER CALLBACKS - callbacks = [ - *hydra_get_callbacks(cfg), - *hydra_get_metric_callbacks(cfg), - ] + # -~-~-~-~-~-~-~-~-~-~-~-~- # + # 
INITIALISE + # -~-~-~-~-~-~-~-~-~-~-~-~- # # HYDRA MODULES - datamodule = HydraDataModule(cfg) - framework_cfg = hydra_create_and_update_framework_config(cfg) - framework = hydra_create_framework(framework_cfg, datamodule, cfg) - - # register schedules - hydra_register_schedules(framework, cfg) + datamodule = hydra_make_datamodule(cfg) + framework = hydra_create_framework(cfg, gpu_batch_augment=datamodule.gpu_batch_augment) + # trainer default kwargs # Setup Trainer trainer = set_debug_trainer(pl.Trainer( + # cannot override these logger=logger, - callbacks=callbacks, gpus=gpus, - # we do this here too so we don't run the final - # metrics, even through we check for it manually. - terminate_on_nan=True, - # TODO: re-enable this in future... something is not compatible - # with saving/checkpointing models + allow enabling from the - # config. Seems like something cannot be pickled? - checkpoint_callback=False, - # additional trainer kwargs - **cfg.trainer, + callbacks=[ + *hydra_get_callbacks(cfg), + *hydra_get_metric_callbacks(cfg), + ModelSummary(max_depth=2), # override default ModelSummary + ], + # additional kwargs from the config + **{ + **dict( + detect_anomaly=False, # this should only be enabled for debugging torch and finding NaN values, slows down execution, not by much though? 
+ enable_checkpointing=False, # TODO: enable this in future + ), + **cfg.trainer, # overrides + } )) # -~-~-~-~-~-~-~-~-~-~-~-~- # - # BEGIN TRAINING + # DEBUG # -~-~-~-~-~-~-~-~-~-~-~-~- # # get config sections print_cfg, boxed_pop = dict(cfg), lambda *keys: make_box_str(OmegaConf.to_yaml({k: print_cfg.pop(k) for k in keys} if keys else print_cfg)) + cfg_str_exp = boxed_pop('action', 'experiment') cfg_str_logging = boxed_pop('logging', 'callbacks', 'metrics') - cfg_str_dataset = boxed_pop('dataset', 'sampling', 'augment') + cfg_str_dataset = boxed_pop('dataset', 'datamodule', 'sampling', 'augment') cfg_str_system = boxed_pop('framework', 'model', 'schedule') cfg_str_settings = boxed_pop('dsettings', 'settings') cfg_str_other = boxed_pop() # print config sections - log.info(f'Final Config For Action: {cfg.action}\n\nLOGGING:{cfg_str_logging}\nDATASET:{cfg_str_dataset}\nSYSTEM:{cfg_str_system}\nTRAINER:{cfg_str_other}\nSETTINGS:{cfg_str_settings}') + log.info(f'Final Config For Action: {cfg.action}\n\nEXPERIMENT:{cfg_str_exp}\nLOGGING:{cfg_str_logging}\nDATASET:{cfg_str_dataset}\nSYSTEM:{cfg_str_system}\nTRAINER:{cfg_str_other}\nSETTINGS:{cfg_str_settings}') - # save hparams TODO: is this a pytorch lightning bug? The trainer should automatically save these if hparams is set? + # -~-~-~-~-~-~-~-~-~-~-~-~- # + # BEGIN TRAINING + # -~-~-~-~-~-~-~-~-~-~-~-~- # + + # save hparams framework.hparams.update(cfg) if trainer.logger: - trainer.logger.log_hyperparams(framework.hparams) + trainer.logger.log_hyperparams(framework.hparams) # TODO: is this a pytorch lightning bug? The trainer should automatically save these if hparams is set? 
# fit the model # -- if an error/signal occurs while pytorch lightning is @@ -389,15 +348,14 @@ def action_train(cfg: DictConfig): trainer.fit(framework, datamodule=datamodule) # -~-~-~-~-~-~-~-~-~-~-~-~- # - # cleanup this run + # -~-~-~-~-~-~-~-~-~-~-~-~- # + try: wandb.finish() except: pass - # -~-~-~-~-~-~-~-~-~-~-~-~- # - # available actions ACTIONS = { @@ -422,42 +380,13 @@ def run_action(cfg: DictConfig): # ========================================================================= # -# path to root directory containing configs -CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'config')) -# root config existing inside `CONFIG_ROOT`, with '.yaml' appended. -CONFIG_NAME = 'config' - - if __name__ == '__main__': - - # register a custom OmegaConf resolver that allows us to put in a ${exit:msg} that exits the program - # - if we don't register this, the program will still fail because we have an unknown - # resolver. This just prettifies the output. - class ConfigurationError(Exception): - pass - - def _error_resolver(msg: str): - raise ConfigurationError(msg) - - OmegaConf.register_new_resolver('exit', _error_resolver) - - @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME) - def hydra_main(cfg: DictConfig): - try: - run_action(cfg) - except Exception as e: - log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=True) - except: - log_error_and_exit(err_type='experiment error', err_msg='', exc_info=True) - - try: - hydra_main() - except KeyboardInterrupt as e: - log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) - except Exception as e: - log_error_and_exit(err_type='hydra error', err_msg=str(e), exc_info=True) - except: - log_error_and_exit(err_type='hydra error', err_msg='', exc_info=True) + # launch the action + hydra_main( + callback=run_action, + config_name='config', + log_level=logging.INFO, + ) # ========================================================================= # diff --git 
a/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py b/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py new file mode 100644 index 00000000..5a14483e --- /dev/null +++ b/experiment/util/_hydra_searchpath_plugin_/hydra_plugins/searchpath_plugin.py @@ -0,0 +1,33 @@ +""" +This file is currently hacky, and hopefully temporary! See: +https://github.com/facebookresearch/hydra/issues/2001 +""" + +import logging +import os + +from hydra.core.config_search_path import ConfigSearchPath +from hydra.plugins.search_path_plugin import SearchPathPlugin + + +log = logging.getLogger(__name__) + + +class DisentExperimentSearchPathPlugin(SearchPathPlugin): + + def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: + from experiment.util.hydra_main import _DISENT_CONFIG_DIRS + # find paths + paths = [ + *os.environ.get('DISENT_CONFIGS_PREPEND', '').split(';'), + *_DISENT_CONFIG_DIRS, + *os.environ.get('DISENT_CONFIGS_APPEND', '').split(';'), + ] + # print information + log.info(f' [disent-search-path-plugin]: Activated hydra plugin: {self.__class__.__name__}') + log.info(f' [disent-search-path-plugin]: To register more search paths, adjust the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables!') + # add paths + for path in paths: + if path: + log.info(f' [disent-search-path] - {repr(path)}') + search_path.append(provider='disent-searchpath-plugin', path=os.path.abspath(path)) diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index daa4c191..2bf575b6 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -24,6 +24,9 @@ import logging import warnings +from typing import Any +from typing import Dict +from typing import Optional import hydra import torch.utils.data @@ -81,37 +84,54 @@ class HydraDataModule(pl.LightningDataModule): - def __init__(self, hparams: DictConfig): + def __init__( + self, + data: Dict[str, Any], # = dataset.data + 
sampler: Dict[str, Any], # = sampling._sampler_.sampler_cls + transform: Optional[Dict[str, Any]] = None, # = dataset.transform + augment: Optional[Dict[str, Any]] = None, # = augment.augment_cls + dataloader_kwargs: Optional[Dict[str, Any]] = None, # = dataloader + augment_on_gpu: bool = False, # = dsettings.dataset.gpu_augment + using_cuda: Optional[bool] = False, # = self.hparams.dsettings.trainer.cuda + prepare_data_per_node: bool = True, # DataHooks.prepare_data_per_node + return_indices: bool = False, # = framework.meta.requires_indices + return_factors: bool = False, # = framework.meta.requires_factors + ): super().__init__() - # support pytorch lightning < 1.4 - if not hasattr(self, 'hparams'): - self.hparams = DictConfig(hparams) - else: - self.hparams.update(hparams) + # OVERRIDE: + self.prepare_data_per_node = prepare_data_per_node + # save hparams + self.save_hyperparameters() + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # transform: prepares data from datasets - self.data_transform = hydra.utils.instantiate(self.hparams.dataset.transform) + self.data_transform = hydra.utils.instantiate(transform) assert (self.data_transform is None) or callable(self.data_transform) # input_transform_aug: augment data for inputs, then apply input_transform - self.input_transform = hydra.utils.instantiate(self.hparams.augment.augment_cls) - assert (self.input_transform is None) or callable(self.input_transform) + self.input_transform = hydra.utils.instantiate(augment) + assert (self.input_transform is None) or callable(self.input_transform) # should be: `Callable[[torch.Tensor], torch.Tensor]` + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # batch_augment: augments transformed data for inputs, should be applied across a batch # which version of the dataset we need to use if GPU augmentation is enabled or not. # - corresponds to below in train_dataloader() - if self.hparams.dsettings.dataset.gpu_augment: - # TODO: this is outdated! 
- self.batch_augment = DisentDatasetTransform(transform=self.input_transform) - warnings.warn('`gpu_augment=True` is outdated and may no longer be equivalent to `gpu_augment=False`') + if augment_on_gpu: + self._gpu_batch_augment = DisentDatasetTransform(transform=self.input_transform) + warnings.warn('`augment_on_gpu=True` is outdated and may no longer be equivalent to `augment_on_gpu=False`') else: - self.batch_augment = None + self._gpu_batch_augment = None + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # datasets initialised in setup() self.dataset_train_noaug: DisentDataset = None self.dataset_train_aug: DisentDataset = None + @property + def gpu_batch_augment(self) -> Optional[DisentDatasetTransform]: + return self._gpu_batch_augment + def prepare_data(self) -> None: # *NB* Do not set model parameters here. # - Instantiate data once to download and prepare if needed. # - trainer.prepare_data_per_node affects this functions behavior per node. - data = dict(self.hparams.dataset.data) + data = dict(self.hparams.data) if 'in_memory' in data: del data['in_memory'] # create the data @@ -124,11 +144,11 @@ def prepare_data(self) -> None: def setup(self, stage=None) -> None: # ground truth data log.info(f'Data - Instance') - data = hydra.utils.instantiate(self.hparams.dataset.data) + data = hydra.utils.instantiate(self.hparams.data) # Wrap the data for the framework some datasets need triplets, pairs, etc. # Augmentation is done inside the frameworks so that it can be done on the GPU, otherwise things are very slow. 
- self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=None) - self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=self.input_transform) + self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampler), transform=self.data_transform, augment=None, return_indices=self.hparams.return_indices, return_factors=self.hparams.return_factors) + self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampler), transform=self.data_transform, augment=self.input_transform, return_indices=self.hparams.return_indices, return_factors=self.hparams.return_factors) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Training Dataset: @@ -149,21 +169,25 @@ def train_dataloader(self): """ # Select which version of the dataset we need to use if GPU augmentation is enabled or not. # - corresponds to above in __init__() - if self.hparams.dsettings.dataset.gpu_augment: + if self.hparams.augment_on_gpu: dataset = self.dataset_train_noaug else: dataset = self.dataset_train_aug + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # get default kwargs default_kwargs = { 'shuffle': True, # This should usually be TRUE if cuda is enabled. # About 20% faster with the xysquares dataset, RTX 2060 Rev. 
A, and Intel i7-3930K - 'pin_memory': self.hparams.dsettings.trainer.cuda + 'pin_memory': self.hparams.using_cuda, } # get config kwargs - kwargs = self.hparams.dataloader + kwargs = self.hparams.dataloader_kwargs + if not kwargs: + kwargs = {} # check required keys if ('batch_size' not in kwargs) or ('num_workers' not in kwargs): raise KeyError(f'`dataset.dataloader` must contain keys: ["batch_size", "num_workers"], got: {sorted(kwargs.keys())}') + # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # # create dataloader return torch.utils.data.DataLoader(dataset=dataset, **{**default_kwargs, **kwargs}) diff --git a/experiment/util/hydra_main.py b/experiment/util/hydra_main.py new file mode 100644 index 00000000..ce55611b --- /dev/null +++ b/experiment/util/hydra_main.py @@ -0,0 +1,229 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +import subprocess +import sys +from pathlib import Path +from typing import Callable +from typing import List +from typing import NoReturn +from typing import Optional + +import hydra +from omegaconf import OmegaConf +from omegaconf import DictConfig + +from experiment.util.path_utils import get_current_experiment_number +from experiment.util.path_utils import make_current_experiment_dir +from experiment.util.run_utils import log_error_and_exit + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# VARS # +# ========================================================================= # + + +# experiment/util/_hydra_searchpath_plugin_ +PLUGIN_NAMESPACE = os.path.abspath(os.path.join(__file__, '..', '_hydra_searchpath_plugin_')) + +# experiment/config +EXP_CONFIG_DIR = os.path.abspath(os.path.join(__file__, '../..', 'config')) + +# list of configs +_DISENT_CONFIG_DIRS: List[str] = None + + +# ========================================================================= # +# PATCHING # +# ========================================================================= # + + +def register_searchpath_plugin( + search_dir_main: str = EXP_CONFIG_DIR, + search_dirs_prepend: Optional[List[str]] = None, + search_dirs_append: Optional[List[str]] = None, +): + """ + Patch Hydra: + 1. sets the default search path to `experiment/config` + 2. add to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables + NOTE: --config-dir has lower priority than all these, --config-path has higher priority. + + This function can safely be called multiple times + - unless other functions modify these same libs which is unlikely! 
+ """ + # normalise the config paths + if search_dirs_prepend is None: search_dirs_prepend = [] + if search_dirs_append is None: search_dirs_append = [] + assert isinstance(search_dirs_prepend, (tuple, list)) and all((isinstance(d, str) and d) for d in search_dirs_prepend), f'`search_dirs_prepend` must be a list or tuple of non-empty path strings to directories, got: {repr(search_dirs_prepend)}' + assert isinstance(search_dirs_append, (tuple, list)) and all((isinstance(d, str) and d) for d in search_dirs_append), f'`search_dirs_append` must be a list or tuple of non-empty path strings to directories, got: {repr(search_dirs_append)}' + assert isinstance(search_dir_main, str) and search_dir_main, f'`search_dir_main` must be a non-empty path string to a directory, got: {repr(search_dir_main)}' + # get dirs + config_dirs = [*search_dirs_prepend, search_dir_main, *search_dirs_append] + + # check that it is the same as what has previously been registered, otherwise set the directories! + global _DISENT_CONFIG_DIRS + if _DISENT_CONFIG_DIRS is None: + _DISENT_CONFIG_DIRS = config_dirs + else: + assert _DISENT_CONFIG_DIRS == config_dirs, f'Config dirs have already been registered, on additional calls, registered dirs must be the same as previous values!\n- existing: {_DISENT_CONFIG_DIRS}\n- registered: {config_dirs}' + + # register the experiment's search path plugin with disent, using hydra's auto-detection + # of folders named `hydra_plugins` contained inside `namespace packages` or rather + # packages that are in the `PYTHONPATH` or `sys.path` + # 1. sets the default search path to those registered above + # 2. adds to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables + # NOTE: --config-dir has lower priority than all these, --config-path has higher priority. 
+ if PLUGIN_NAMESPACE not in sys.path: + sys.path.insert(0, PLUGIN_NAMESPACE) + + +def register_hydra_resolvers(): + """ + Patch OmegaConf, enabling various config resolvers: + - enable the ${exit:} resolver for omegaconf/hydra + - enable the ${exp_num:} and ${exp_dir:,} resolvers to detect the experiment number + - enable the ${fmt:,} resolver that wraps `str.format` + - enable the ${abspath:} resolver that wraps `hydra.utils.to_absolute_path` formatting relative paths in relation to the original working directory + - enable the ${rsync_dir:/,/} resolver that returns `/`, but first rsync's the two directories! + + This function can safely be called multiple times + - unless other functions modify these same libs which is unlikely! + """ + + + # register a custom OmegaConf resolver that allows us to put in a ${exit:msg} that exits the program + # - if we don't register this, the program will still fail because we have an unknown + # resolver. This just prettifies the output. + if not OmegaConf.has_resolver('exit'): + class ConfigurationError(Exception): + pass + # resolver function + def _error_resolver(msg: str): + raise ConfigurationError(msg) + # patch omegaconf for hydra + OmegaConf.register_new_resolver('exit', _error_resolver) + + # register a custom OmegaConf resolver that allows us to get the next experiment number from a directory + # - ${exp_num:} returns the current experiment number + if not OmegaConf.has_resolver('exp_num'): + OmegaConf.register_new_resolver('exp_num', get_current_experiment_number) + # - ${exp_dir:,} returns the current experiment folder with the name appended + if not OmegaConf.has_resolver('exp_dir'): + OmegaConf.register_new_resolver('exp_dir', make_current_experiment_dir) + + # register a function that pads an integer to a specified length + # - ${fmt:"{:04d}",42} -> "0042" + if not OmegaConf.has_resolver('fmt'): + OmegaConf.register_new_resolver('fmt', str.format) + + # register hydra helper functions + # - ${abspath:} convert a 
relative path to an abs path using the original hydra working directory, not the changed experiment dir. + if not OmegaConf.has_resolver('abspath'): + OmegaConf.register_new_resolver('abspath', hydra.utils.to_absolute_path) + + # register a copy directory function + # - useful if datasets are already prepared on a shared drive and need to be copied to a temp drive for example! + if not OmegaConf.has_resolver('rsync_dir'): + def rsync_dir(src: str, dst: str) -> str: + src, dst = Path(src), Path(dst) + # checks + assert src.name and src.is_absolute(), f'src path must be absolute and not the root: {repr(str(src))}' + assert dst.name and dst.is_absolute(), f'dst path must be absolute and not the root: {repr(str(dst))}' + assert src.name == dst.name, f'src and dst paths must point to dirs with the same names: src.name={repr(src.name)}, dst.name={repr(dst.name)}' + # synchronize dirs + logging.info(f'rsync files:\n- src={repr(str(src))}\n- dst={repr(str(dst))}') + # create the parent dir and copy files into the parent + dst.parent.mkdir(parents=True, exist_ok=True) + returncode = subprocess.Popen(['rsync', '-avh', str(src), str(dst.parent)]).wait() + if returncode != 0: + raise RuntimeError('Failed to rsync files!') + # return the destination dir + return str(dst) + # REGISTER + OmegaConf.register_new_resolver('rsync_dir', rsync_dir) + + +# ========================================================================= # +# RUN HYDRA # +# ========================================================================= # + + +def patch_hydra( + # config search path + search_dir_main: str = EXP_CONFIG_DIR, + search_dirs_prepend: Optional[List[str]] = None, + search_dirs_append: Optional[List[str]] = None, +): + # Patch Hydra and OmegaConf: + register_searchpath_plugin(search_dir_main=search_dir_main, search_dirs_prepend=search_dirs_prepend, search_dirs_append=search_dirs_append) + register_hydra_resolvers() + + +def hydra_main( + callback: Callable[[DictConfig], NoReturn], + config_name: 
str = 'config', + # config search path + search_dir_main: str = EXP_CONFIG_DIR, + search_dirs_prepend: Optional[List[str]] = None, + search_dirs_append: Optional[List[str]] = None, + # logging + log_level: Optional[int] = logging.INFO, + log_exc_info_callback: bool = True, + log_exc_info_hydra: bool = False, +): + # manually set log level before hydra initialises! + if log_level is not None: + logging.basicConfig(level=log_level) + + # Patch Hydra and OmegaConf: + patch_hydra(search_dir_main=search_dir_main, search_dirs_prepend=search_dirs_prepend, search_dirs_append=search_dirs_append) + + @hydra.main(config_path=None, config_name=config_name) + def _hydra_main(cfg: DictConfig): + try: + callback(cfg) + except Exception as e: + log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=log_exc_info_callback) + except: + log_error_and_exit(err_type='experiment error', err_msg='', exc_info=log_exc_info_callback) + + try: + _hydra_main() + except KeyboardInterrupt as e: + log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) + except Exception as e: + log_error_and_exit(err_type='hydra error', err_msg=str(e), exc_info=log_exc_info_hydra) + except: + log_error_and_exit(err_type='hydra error', err_msg='', exc_info=log_exc_info_hydra) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/experiment/util/path_utils.py b/experiment/util/path_utils.py new file mode 100644 index 00000000..57858aec --- /dev/null +++ b/experiment/util/path_utils.py @@ -0,0 +1,150 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including 
without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +import os +import re +from pathlib import Path +from typing import Optional +from typing import Tuple +from typing import Union + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# PATH HELPERS # +# ========================================================================= # + + +_EXPERIMENT_SEP = '_' +_EXPERIMENT_RGX = re.compile(f'^([0-9]+)({_EXPERIMENT_SEP}.+)?$') + + +def get_max_experiment_number(root_dir: str, return_path: bool = False) -> Union[int, Tuple[int, Optional[str]]]: + """ + Get the maximum experiment number in the specified directory. Experiment directories + all start with a numerical value. + - eg. "1", "00002", "3_name", "00042_name" are all valid subdirectories. + - eg. "name", "name_1", "name_00001", "99999_image.png" are all invalid and are + ignored. Either their name format is wrong or they are a file. 
+ + If all of the above directories are used as an example, then this function will + return the value 42, corresponding to "00042_name" + """ + # check the dirs exist + if not os.path.exists(root_dir): + raise FileNotFoundError(f'The given experiments directory does not exist: {repr(root_dir)} ({repr(os.path.abspath(root_dir))})') + elif not os.path.isdir(root_dir): + raise NotADirectoryError(f'The given experiments path exists, but is not a directory: {repr(root_dir)} ({repr(os.path.abspath(root_dir))})') + # linear search over each file in the dir + max_num, max_path = 0, None + for file in os.listdir(root_dir): + # skip if not a directory + if not os.path.isdir(os.path.join(root_dir, file)): + continue + # skip if the file name does not match + match = _EXPERIMENT_RGX.search(file) + if not match: + continue + # update the maximum number + num, _ = match.groups() + num = int(num) + if num > max_num: + max_num, max_path = num, file + # done! + if return_path: + return max_num, max_path + return max_num + + +_CURRENT_EXPERIMENT_NUM: Optional[int] = None +_CURRENT_EXPERIMENT_DIR: Optional[str] = None + + +def get_current_experiment_number(root_dir: str) -> int: + """ + Get the next experiment number from the experiment directory, and cache + the result for future calls of this function for the current instance of the program. + - The next time the program is run, this value will differ. + + For example, if the `root_dir` contains the directories: "00001_name", "00041", then + this function will return the next value which is `42` on all subsequent calls, even + if a directory for experiment 42 is created during the current program's lifetime. 
+ """ + global _CURRENT_EXPERIMENT_NUM + if _CURRENT_EXPERIMENT_NUM is None: + _CURRENT_EXPERIMENT_NUM = get_max_experiment_number(root_dir, return_path=False) + 1 + return _CURRENT_EXPERIMENT_NUM + + +def get_current_experiment_dir(root_dir: str, name: Optional[str] = None) -> str: + """ + Like `get_current_experiment_number` which computes the next experiment number, this + function computes the next experiment path, which appends a name to the computed number. + + The result is cached for the lifetime of the program, however, on subsequent calls of + this function, the computed name must always match the original value otherwise an + error is thrown! This is to prevent experiments with duplicate numbers from being created! + """ + if name is not None: + assert Path(name).name == name, f'The given name is not valid: {repr(name)}' + # make the dirname & normalise the path + num = get_current_experiment_number(root_dir) + dir_name = f'{num:05d}{_EXPERIMENT_SEP}{name}' if name else f'{num:05d}' + exp_dir = os.path.abspath(os.path.join(root_dir, dir_name)) + # cache the experiment name or check against the existing cache + global _CURRENT_EXPERIMENT_DIR + if _CURRENT_EXPERIMENT_DIR is None: + _CURRENT_EXPERIMENT_DIR = exp_dir + if exp_dir != _CURRENT_EXPERIMENT_DIR: + raise RuntimeError(f'Current experiment directory has already been set: {repr(_CURRENT_EXPERIMENT_DIR)} This does not match what was computed: {repr(exp_dir)}') + # done! + return _CURRENT_EXPERIMENT_DIR + + +def make_current_experiment_dir(root_dir: str, name: Optional[str] = None) -> str: + """ + Like `get_current_experiment_dir`, but create any of the directories if needed. + - Both the `root_dir` and the computed subdir for the current experiment will be created. + """ + root_dir = os.path.abspath(root_dir) + # make the root directory if it does not exist! + if not os.path.exists(root_dir): + log.info(f'root experiments directory does not exist, creating... 
{repr(root_dir)}') + os.makedirs(root_dir, exist_ok=True) + # get the current dir + current_dir = get_current_experiment_dir(root_dir, name) + # make the current dir + if not os.path.exists(current_dir): + log.info(f'current experiment directory does not exist, creating... {repr(current_dir)}') + os.makedirs(current_dir, exist_ok=True) + # done! + return current_dir + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/requirements-test.txt b/requirements-test.txt index af5c6933..79de0ba6 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,4 +2,21 @@ pytest>=6.2.4 pytest-cov>=2.12.1 # requires pytorch to be installed first (duplicated in requirements-experiment.txt) +# - we need `nvcc` to be installed first, otherwise GPU kernel extensions will not be +# compiled and this error will silently be skipped. If you get an error such as +# $ conda install -c nvidia cuda-nvcc +# - Make sure that the version of torch corresponds to the version of `nvcc`, torch needs +# to be compiled with the same version! Install the correct version from: +# https://pytorch.org/get-started/locally/ By default torch compiled with 10.2 is installed, +# but `nvcc` will probably want to install 11. 
+# CUDA 10.2 (as of 2022-03-15) EITHER OF: +# $ pip3 install torch torchvision torchaudio +# $ conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch +# CUDA 11.3 (as of 2022-03-15) EITHER OF: +# $ pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html +# $ conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch +# - I personally just manage my cuda version manually, installing the correct cudatoolkit from: https://developer.nvidia.com/cuda-toolkit-archive +# Then making sure that: +# PATH contains: "/usr/local/cuda/bin" +# LD_LIBRARY_PATH contains: "/usr/local/cuda/lib64" torchsort>=0.1.4 diff --git a/setup.py b/setup.py index 99db0047..216eeb25 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ author="Nathan Juraj Michlo", author_email="NathanJMichlo@gmail.com", - version="0.3.4", + version="0.4.0", python_requires=">=3.8", # we make use of standard library features only in 3.8 packages=setuptools.find_packages(), diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 0ec7593f..fec35bba 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -22,13 +22,15 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import logging import os import os.path -import hydra import pytest import experiment.run as experiment_run +from experiment.util.hydra_main import hydra_main +from tests.util import temp_environ from tests.util import temp_sys_args @@ -37,19 +39,28 @@ # ========================================================================= # -@pytest.mark.parametrize('args', [ - ['run_action=skip'], - ['run_action=prepare_data'], - ['run_action=train'], +@pytest.mark.parametrize(('env', 'args'), [ + # test the standard configs + (dict(), ['run_action=skip']), + (dict(), ['run_action=prepare_data']), + (dict(), ['run_action=train']), ]) -def test_experiment_run(args): +def test_experiment_run(env, args): + # show full errors in hydra os.environ['HYDRA_FULL_ERROR'] = '1' - # TODO: why does this not work when config_path is absolute? - # ie. config_path=os.path.join(os.path.dirname(experiment_run.__file__), 'config') - with temp_sys_args([experiment_run.__file__, *args]): - hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.run_action) - hydra_main() + # temporarily set the environment and the arguments + with temp_environ(env), temp_sys_args([experiment_run.__file__, *args]): + # run the hydra experiment + # 1. sets the default search path to `experiment/config` + # 2. add to the search path with the `DISENT_CONFIGS_PREPEND` and `DISENT_CONFIGS_APPEND` environment variables + # 3. enable the ${exit:} and various other resolvers for omegaconf/hydra + hydra_main( + callback=experiment_run.run_action, + config_name='config_test', + log_level=logging.DEBUG, + ) + # ========================================================================= # diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 91c1750c..f454050f 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -22,6 +22,7 @@ # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import pickle from dataclasses import asdict from functools import partial @@ -47,7 +48,7 @@ # ========================================================================= # -@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], [ +_TEST_FRAMEWORKS = [ # AE - unsupervised (Ae, dict(), XYObjectData), # AE - weakly supervised @@ -69,8 +70,11 @@ (AdaGVaeMinimal, dict(), XYObjectData), # VAE - supervised (TripletVae, dict(), XYObjectData), - (TripletVae, dict(disable_decoder=True, disable_reg_loss=True, disable_posterior_scale=0.5), XYObjectData), -]) + (TripletVae, dict(detach_decoder=True, disable_reg_loss=True), XYObjectData), +] + + +@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], _TEST_FRAMEWORKS) def test_frameworks(Framework, cfg_kwargs, Data): DataSampler = { 1: GroundTruthSingleSampler, @@ -90,9 +94,29 @@ def test_frameworks(Framework, cfg_kwargs, Data): cfg=Framework.cfg(**cfg_kwargs) ) + # test pickling before training + pickle.dumps(framework) + + # train! trainer = pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, fast_dev_run=True) trainer.fit(framework, dataloader) + # test pickling after training, something may have changed! + pickle.dumps(framework) + + +@pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], _TEST_FRAMEWORKS) +def test_framework_pickling(Framework, cfg_kwargs, Data): + framework = Framework( + model=AutoEncoder( + encoder=EncoderLinear(x_shape=(64, 64, 3), z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1), + decoder=DecoderLinear(x_shape=(64, 64, 3), z_size=6), + ), + cfg=Framework.cfg(**cfg_kwargs) + ) + # test pickling! 
+ pickle.dumps(framework) + def test_framework_config_defaults(): # import torch @@ -102,8 +126,7 @@ def test_framework_config_defaults(): optimizer_kwargs=None, recon_loss='mse', disable_aug_loss=False, - disable_decoder=False, - disable_posterior_scale=None, + detach_decoder=False, disable_rec_loss=False, disable_reg_loss=False, loss_reduction='mean', @@ -116,8 +139,7 @@ def test_framework_config_defaults(): optimizer_kwargs=None, recon_loss='bce', disable_aug_loss=False, - disable_decoder=False, - disable_posterior_scale=None, + detach_decoder=False, disable_rec_loss=False, disable_reg_loss=False, loss_reduction='mean', diff --git a/tests/test_math.py b/tests/test_math.py index 69c2e4cd..e4962f33 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -37,11 +37,16 @@ from disent.nn.functional import torch_cov_matrix from disent.nn.functional import torch_dct from disent.nn.functional import torch_dct2 +from disent.nn.functional import torch_dist_hamming from disent.nn.functional import torch_gaussian_kernel_2d from disent.nn.functional import torch_idct from disent.nn.functional import torch_idct2 from disent.nn.functional import torch_mean_generalized +from disent.nn.functional import torch_norm +from disent.nn.functional import torch_dist from disent.dataset.transform import ToImgTensorF32 +from disent.nn.functional import torch_norm_euclidean +from disent.nn.functional import torch_norm_manhattan from disent.util import to_numpy @@ -84,12 +89,72 @@ def test_generalised_mean(): assert torch.allclose(torch_mean_generalized(xs, p=-1), torch.as_tensor(hmean(xs, axis=None))) # scipy default axis is 0 # min max - assert torch.allclose(torch_mean_generalized(xs, p='maximum', dim=1), torch.max(xs, dim=1).values) - assert torch.allclose(torch_mean_generalized(xs, p='minimum', dim=1), torch.min(xs, dim=1).values) - assert torch.allclose(torch_mean_generalized(xs, p=np.inf, dim=1), torch.max(xs, dim=1).values) - assert torch.allclose(torch_mean_generalized(xs, 
p=-np.inf, dim=1), torch.min(xs, dim=1).values) + assert torch.allclose(torch_mean_generalized(xs, p='maximum', dim=1), torch.amax(xs, dim=1)) + assert torch.allclose(torch_mean_generalized(xs, p='minimum', dim=1), torch.amin(xs, dim=1)) + assert torch.allclose(torch_mean_generalized(xs, p=np.inf, dim=1), torch.amax(xs, dim=1)) + assert torch.allclose(torch_mean_generalized(xs, p=-np.inf, dim=1), torch.amin(xs, dim=1)) +def test_p_norm(): + xs = torch.abs(torch.randn(2, 1000, 3, dtype=torch.float64)) + + inf = float('inf') + + # torch equivalents + assert torch.allclose(torch_norm(xs, p=1, dim=-1), torch.linalg.norm(xs, ord=1, dim=-1)) + assert torch.allclose(torch_norm(xs, p=2, dim=-1), torch.linalg.norm(xs, ord=2, dim=-1)) + assert torch.allclose(torch_norm(xs, p='maximum', dim=-1), torch.linalg.norm(xs, ord=inf, dim=-1)) + + # torch equivalents -- less than zero [FAIL] + with pytest.raises(ValueError, match='p-norm cannot have a p value less than 1'): + assert torch.allclose(torch_norm(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1)) + with pytest.raises(ValueError, match='p-norm cannot have a p value less than 1'): + assert torch.allclose(torch_norm(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1)) + + # torch equivalents -- less than zero + assert torch.allclose(torch_dist(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1)) + assert torch.allclose(torch_dist(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1)) + assert torch.allclose(torch_dist(xs, p=0, dim=-1), torch.linalg.norm(xs, ord=0, dim=-1)) + assert torch.allclose(torch_dist(xs, p='minimum', dim=-1), torch.linalg.norm(xs, ord=-inf, dim=-1)) + + # test other axes + ys = torch.flatten(torch.moveaxis(xs, 0, -1), start_dim=1, end_dim=-1) + assert torch.allclose(torch_norm(xs, p=2, dim=[0, -1]), torch.linalg.norm(ys, ord=2, dim=-1)) + ys = torch.flatten(xs, start_dim=1, end_dim=-1) + assert torch.allclose(torch_norm(xs, p=1, dim=[-2, -1]), torch.linalg.norm(ys, 
ord=1, dim=-1)) + ys = torch.flatten(torch.moveaxis(xs, -1, 0), start_dim=1, end_dim=-1) + assert torch.allclose(torch_dist(xs, p=0, dim=[0, 1]), torch.linalg.norm(ys, ord=0, dim=-1)) + + # check equal names + assert torch.allclose(torch_dist(xs, dim=1, p='euclidean'), torch_norm_euclidean(xs, dim=1)) + assert torch.allclose(torch_dist(xs, dim=1, p=2), torch_norm_euclidean(xs, dim=1)) + assert torch.allclose(torch_dist(xs, dim=1, p='manhattan'), torch_norm_manhattan(xs, dim=1)) + assert torch.allclose(torch_dist(xs, dim=1, p=1), torch_norm_manhattan(xs, dim=1)) + assert torch.allclose(torch_dist(xs, dim=1, p='hamming'), torch_dist_hamming(xs, dim=1)) + assert torch.allclose(torch_dist(xs, dim=1, p=0), torch_dist_hamming(xs, dim=1)) + + # check axes + assert torch_dist(xs, dim=1, p=2, keepdim=False).shape == (2, 3) + assert torch_dist(xs, dim=1, p=2, keepdim=True).shape == (2, 1, 3) + assert torch_dist(xs, dim=-1, p=1, keepdim=False).shape == (2, 1000) + assert torch_dist(xs, dim=-1, p=1, keepdim=True).shape == (2, 1000, 1) + assert torch_dist(xs, dim=0, p=0, keepdim=False).shape == (1000, 3) + assert torch_dist(xs, dim=0, p=0, keepdim=True).shape == (1, 1000, 3) + assert torch_dist(xs, dim=[0, -1], p=-inf, keepdim=False).shape == (1000,) + assert torch_dist(xs, dim=[0, -1], p=-inf, keepdim=True).shape == (1, 1000, 1) + assert torch_dist(xs, dim=[0, 1], p=inf, keepdim=False).shape == (3,) + assert torch_dist(xs, dim=[0, 1], p=inf, keepdim=True).shape == (1, 1, 3) + + # check norm over all + assert torch_dist(xs, dim=None, p=0, keepdim=False).shape == () + assert torch_dist(xs, dim=None, p=0, keepdim=True).shape == (1, 1, 1) + assert torch_dist(xs, dim=None, p=1, keepdim=False).shape == () + assert torch_dist(xs, dim=None, p=1, keepdim=True).shape == (1, 1, 1) + assert torch_dist(xs, dim=None, p=2, keepdim=False).shape == () + assert torch_dist(xs, dim=None, p=2, keepdim=True).shape == (1, 1, 1) + assert torch_dist(xs, dim=None, p=inf, keepdim=False).shape == () + 
assert torch_dist(xs, dim=None, p=inf, keepdim=True).shape == (1, 1, 1) + def test_dct(): x = torch.randn(128, 3, 64, 32, dtype=torch.float64) @@ -141,7 +206,7 @@ def test_fft_conv2d(): kernel = torch_gaussian_kernel_2d(sigma=i) out_cnv = torch_conv2d_channel_wise(signal=batch, kernel=kernel)[0] out_fft = torch_conv2d_channel_wise_fft(signal=batch, kernel=kernel)[0] - assert torch.max(torch.abs(out_cnv - out_fft)) < 1e-6 + assert torch.amax(torch.abs(out_cnv - out_fft)) < 1e-6 # ========================================================================= # diff --git a/tests/test_registry.py b/tests/test_registry.py index c173d969..1bde1aa3 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -22,8 +22,8 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from disent.registry import REGISTRIES +import pytest +import disent.registry as R # ========================================================================= # @@ -32,31 +32,28 @@ COUNTS = { - 'DATASETS': 6, + 'DATASETS': 10, 'SAMPLERS': 8, 'FRAMEWORKS': 10, - 'RECON_LOSSES': 6, - 'LATENT_DISTS': 2, + 'RECON_LOSSES': 9, + 'LATENT_HANDLERS': 2, 'OPTIMIZERS': 30, 'METRICS': 5, 'SCHEDULES': 5, 'MODELS': 8, + 'KERNELS': 2, } - -def test_registry_loading(): +@pytest.mark.parametrize('registry_key', COUNTS.keys()) +def test_registry_loading(registry_key): # load everything and check the counts - total = 0 - for registry in REGISTRIES: - count = 0 - for name in REGISTRIES[registry]: - loaded = REGISTRIES[registry][name] - count += 1 - total += 1 - assert COUNTS[registry] == count, f'invalid count for: {registry}' - assert total == sum(COUNTS.values()), f'invalid total' + count = 0 + for example in R.REGISTRIES[registry_key]: + loaded = R.REGISTRIES[registry_key][example] + count += 1 + assert count == COUNTS[registry_key], f'invalid count for: {registry_key}' # ========================================================================= # diff --git 
a/tests/test_to_img.py b/tests/test_to_img.py new file mode 100644 index 00000000..5ecd68ae --- /dev/null +++ b/tests/test_to_img.py @@ -0,0 +1,328 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2022 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +import numpy as np +import pytest +import torch +from PIL import Image + +from disent.util.visualize.vis_img import _torch_to_images_normalise_args +from disent.util.visualize.vis_img import numpy_to_images +from disent.util.visualize.vis_img import numpy_to_pil_images +from disent.util.visualize.vis_img import torch_to_images +from disent.util.visualize.vis_img import _ALLOWED_DTYPES + + +# ========================================================================= # +# Tests # +# ========================================================================= # + + +def test_torch_to_images_basic(): + inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) + inp_uint8 = (inp_float * 127 + 63).to(torch.uint8) + # check runs + out = torch_to_images(inp_float) + assert out.dtype == torch.uint8 + out = torch_to_images(inp_uint8) + assert out.dtype == torch.uint8 + out = torch_to_images(inp_float, in_dtype=None, out_dtype=None) + assert out.dtype == inp_float.dtype + out = torch_to_images(inp_uint8, in_dtype=None, out_dtype=None) + assert out.dtype == inp_uint8.dtype + + +def test_torch_to_images_permutations(): + inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) + inp_uint8 = (inp_float * 127 + 63).to(torch.uint8) + + # general checks + def check_all(inputs, in_dtype=None): + float_results, int_results = [], [] + for out_dtype in _ALLOWED_DTYPES: + out = torch_to_images(inputs, in_dtype=in_dtype, out_dtype=out_dtype) + stats = torch.stack([out.min().to(torch.float64), out.max().to(torch.float64), out.to(dtype=torch.float64).mean()]) + (float_results if out_dtype.is_floating_point else int_results).append(stats) + for a, b in zip(float_results[:-1], float_results[1:]): assert torch.allclose(a, b) + for a, b in zip(int_results[:-1], int_results[1:]): assert torch.allclose(a, b) + + # check type permutations + check_all(inp_float, torch.float32) + check_all(inp_uint8, torch.uint8) + + 
+def test_torch_to_images_preserve_type(): + for dtype in _ALLOWED_DTYPES: + tensor = (torch.rand(8, 3, 64, 64) * (1 if dtype.is_floating_point else 255)).to(dtype) + out = torch_to_images(tensor, in_dtype=dtype, out_dtype=dtype, clamp_mode='warn') + assert out.dtype == dtype + + +def test_torch_to_images_arg_helper(): + assert _torch_to_images_normalise_args((64, 128, 3), torch.uint8, 'HWC', 'CHW', None, None) == ((-1, -3, -2), torch.uint8, torch.uint8, -3) + assert _torch_to_images_normalise_args((64, 128, 3), torch.uint8, 'HWC', 'HWC', None, None) == ((-3, -2, -1), torch.uint8, torch.uint8, -1) + + +def test_torch_to_images_invalid_args(): + inp_float = torch.rand(8, 3, 64, 64, dtype=torch.float32) + + # check tensor + with pytest.raises(TypeError, match="images must be of type"): + torch_to_images(tensor=None) + with pytest.raises(ValueError, match='dim "C", required: 1 or 3'): + torch_to_images(tensor=torch.rand(8, 2, 16, 16, dtype=torch.float32)) + with pytest.raises(ValueError, match='dim "C", required: 1 or 3'): + torch_to_images(tensor=torch.rand(8, 16, 16, 3, dtype=torch.float32)) + with pytest.raises(ValueError, match='images must have 3 or more dimensions corresponding to'): + torch_to_images(tensor=torch.rand(16, 16, dtype=torch.float32)) + + # check dims + with pytest.raises(TypeError, match="in_dims must be of type"): + torch_to_images(inp_float, in_dims=None) + with pytest.raises(TypeError, match="out_dims must be of type"): + torch_to_images(inp_float, out_dims=None) + with pytest.raises(KeyError, match="in_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"): + torch_to_images(inp_float, in_dims='INVALID') + with pytest.raises(KeyError, match="out_dims contains the symbols: 'INVALID', must contain only permutations of: 'CHW'"): + torch_to_images(inp_float, out_dims='INVALID') + with pytest.raises(KeyError, match="in_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"): + 
torch_to_images(inp_float, in_dims='CHWW') + with pytest.raises(KeyError, match="out_dims contains the symbols: 'CHWW', must contain only permutations of: 'CHW'"): + torch_to_images(inp_float, out_dims='CHWW') + + # check dtypes + with pytest.raises(TypeError, match="images dtype: torch.float32 does not match in_dtype: torch.uint8"): + torch_to_images(inp_float, in_dtype=torch.uint8) + with pytest.raises(TypeError, match='in_dtype is not allowed'): + torch_to_images(inp_float, in_dtype=torch.complex64) + with pytest.raises(TypeError, match='out_dtype is not allowed'): + torch_to_images(inp_float, out_dtype=torch.complex64) + with pytest.raises(TypeError, match='in_dtype is not allowed'): + torch_to_images(inp_float, in_dtype=torch.float16) + with pytest.raises(TypeError, match='out_dtype is not allowed'): + torch_to_images(inp_float, out_dtype=torch.float16) + + +def _check(target_shape, target_dtype, img, m=None, M=None): + assert img.dtype == target_dtype + assert img.shape == target_shape + if isinstance(img, torch.Tensor): img = img.numpy() + if m is not None: assert np.isclose(img.min(), m), f'min mismatch: {img.min()} (actual) != {m} (expected)' + if M is not None: assert np.isclose(img.max(), M), f'max mismatch: {img.max()} (actual) != {M} (expected)' + + +def test_torch_to_images_adv(): + # CHW + nchw_float = torch.rand(8, 3, 64, 32, dtype=torch.float32) + nchw_uint8 = torch.randint(0, 255, (8, 3, 64, 32), dtype=torch.uint8) + # HWC + nhwc_float = torch.rand(8, 64, 32, 3, dtype=torch.float32) + nhwc_uint8 = torch.randint(0, 255, (8, 64, 32, 3), dtype=torch.uint8) + + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float)) # make sure default for numpy is CHW + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8)) # make sure default for numpy is CHW + + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW')) + _check((8, 64, 32, 3), torch.uint8, 
torch_to_images(nhwc_float, 'HWC')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC')) + + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW', 'HWC')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'HWC')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC')) + + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_float, 'CHW', 'CHW')) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW')) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'CHW')) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW')) + + # random permute + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW')) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW')) + _check((8, 3, 32, 64), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CWH')) + _check((8, 3, 32, 64), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CWH')) + + _check((8, 64, 3, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HCW')) + _check((8, 64, 3, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HCW')) + _check((8, 32, 3, 64), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'WCH')) + _check((8, 32, 3, 64), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'WCH')) + + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC')) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC')) + _check((8, 32, 64, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'WHC')) + _check((8, 32, 64, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'WHC')) + + _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float32)) + _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float32)) + _check((8, 64, 32, 
3), torch.float32, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float32)) + _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float32)) + + _check((8, 64, 32, 3), torch.float64, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float64)) + _check((8, 64, 32, 3), torch.float64, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float64)) + _check((8, 64, 32, 3), torch.float64, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float64)) + _check((8, 64, 32, 3), torch.float64, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float64)) + + # random, but chance of this failing is almost impossible + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255) + _check((8, 64, 32, 3), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255) + + # random, but chance of this failing is almost impossible + _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_float, 'CHW', 'HWC', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1) + _check((8, 64, 32, 3), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'HWC', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1) + _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_float, 'HWC', 'HWC', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1) + _check((8, 64, 32, 3), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'HWC', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1) + + # random, but chance of this failing is almost impossible + _check((8, 3, 64, 32), torch.uint8, 
torch_to_images(nchw_float, 'CHW', 'CHW', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nchw_uint8, 'CHW', 'CHW', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_float, 'HWC', 'CHW', out_dtype=torch.uint8, in_min=0.25, in_max=0.75), m=0, M=255) + _check((8, 3, 64, 32), torch.uint8, torch_to_images(nhwc_uint8, 'HWC', 'CHW', out_dtype=torch.uint8, in_min=64, in_max=192), m=0, M=255) + + # random, but chance of this failing is almost impossible + _check((8, 3, 64, 32), torch.float32, torch_to_images(nchw_float, 'CHW', 'CHW', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1) + _check((8, 3, 64, 32), torch.float32, torch_to_images(nchw_uint8, 'CHW', 'CHW', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1) + _check((8, 3, 64, 32), torch.float32, torch_to_images(nhwc_float, 'HWC', 'CHW', out_dtype=torch.float32, in_min=0.25, in_max=0.75), m=0, M=1) + _check((8, 3, 64, 32), torch.float32, torch_to_images(nhwc_uint8, 'HWC', 'CHW', out_dtype=torch.float32, in_min=64, in_max=192), m=0, M=1) + + # check clamping + with pytest.raises(ValueError, match='is outside of the required range'): + torch_to_images(nchw_float, 'CHW', out_dtype=torch.float32, clamp_mode='error', in_min=0.25, in_max=0.75) + with pytest.raises(ValueError, match='is outside of the required range'): + torch_to_images(nchw_uint8, 'CHW', out_dtype=torch.float32, clamp_mode='error', in_min=64, in_max=192) + with pytest.raises(ValueError, match='is outside of the required range'): + torch_to_images(nhwc_float, 'HWC', out_dtype=torch.float32, clamp_mode='error', in_min=0.25, in_max=0.75) + with pytest.raises(ValueError, match='is outside of the required range'): + torch_to_images(nhwc_uint8, 'HWC', out_dtype=torch.float32, clamp_mode='error', in_min=64, in_max=192) + with pytest.raises(KeyError, match="invalid clamp mode: 'asdf'"): + 
torch_to_images(nhwc_uint8, 'HWC', out_dtype=torch.float32, clamp_mode='asdf', in_min=64, in_max=192) + + +def test_numpy_to_pil_image(): + # CHW + nchw_float = np.random.rand(8, 3, 64, 32) + nchw_uint8 = np.random.randint(0, 255, (8, 3, 64, 32), dtype='uint8') + + # HWC + nhwc_float = np.random.rand(8, 64, 32, 3) + nhwc_uint8 = np.random.randint(0, 255, (8, 64, 32, 3), dtype='uint8') + + with pytest.raises(ValueError, match='images do not have the correct number of channels for dim "C"'): + numpy_to_images(nchw_float) # make sure default for numpy is HWC + with pytest.raises(ValueError, match='images do not have the correct number of channels for dim "C"'): + numpy_to_images(nchw_uint8) # make sure default for numpy is HWC + + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float)) # make sure default for numpy is HWC + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8)) # make sure default for numpy is HWC + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW')) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW')) + + for pil_images in [ + numpy_to_pil_images(nchw_float, 'CHW'), + numpy_to_pil_images(nhwc_float), + ]: + assert isinstance(pil_images, np.ndarray) + assert pil_images.shape == (8,) + for pil_image in pil_images: + pil_image: Image.Image + assert isinstance(pil_image, Image.Image) + assert pil_image.width == 32 + assert pil_image.height == 64 + + # single image should be returned as an array of shape () + pil_image: np.ndarray = numpy_to_pil_images(np.random.rand(64, 32, 3)) + assert pil_image.shape == () + assert isinstance(pil_image, np.ndarray) + assert isinstance(pil_image.tolist(), Image.Image) + + # check arb size + pil_images: np.ndarray = numpy_to_pil_images(np.random.rand(4, 5, 2, 16, 32, 3)) + assert pil_images.shape == (4, 5, 2) + + +def test_numpy_image_min_max(): + # CHW + nchw_float = np.random.rand(8, 3, 64, 32) + nchw_uint8 = np.random.randint(0, 255, (8, 3, 64, 32), dtype='uint8') + + # HWC 
+ nhwc_float = np.random.rand(8, 64, 32, 3) + nhwc_uint8 = np.random.randint(0, 255, (8, 64, 32, 3), dtype='uint8') + + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=0, in_max=1)) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=0, in_max=255)) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=0, in_max=1)) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=0, in_max=255)) + + # OUT: HWC + + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=(0, 0, 0), in_max=(1, 1, 1))) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=(0, 0, 0), in_max=(255, 255, 255))) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1))) + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255))) + + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_float, 'HWC', in_min=(0,), in_max=(1,))) # should maybe disable this from working? + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', in_min=(0,), in_max=(255,))) # should maybe disable this from working? + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_float, 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working? + _check((8, 64, 32, 3), 'uint8', numpy_to_images(nchw_uint8, 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working? 
+ + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', in_min=(0,), in_max=(1,))) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', in_min=(0,), in_max=(255,))) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', in_min=(0,), in_max=(1,))) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', in_min=(0,), in_max=(255,))) + + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', in_min=0, in_max=1)) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', in_min=0, in_max=255)) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', in_min=0, in_max=1)) + _check((8, 64, 32, 1), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', in_min=0, in_max=255)) + + # OUT: CHW + + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_float, 'HWC', 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1))) + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255))) + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_float, 'CHW', 'CHW', in_min=(0, 0, 0), in_max=(1, 1, 1))) + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_uint8, 'CHW', 'CHW', in_min=(0, 0, 0), in_max=(255, 255, 255))) + + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_float, 'HWC', 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working? + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nhwc_uint8, 'HWC', 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working? + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_float, 'CHW', 'CHW', in_min=(0,), in_max=(1,))) # should maybe disable this from working? + _check((8, 3, 64, 32), 'uint8', numpy_to_images(nchw_uint8, 'CHW', 'CHW', in_min=(0,), in_max=(255,))) # should maybe disable this from working? 
+ + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', 'CHW', in_min=(0,), in_max=(1,))) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', 'CHW', in_min=(0,), in_max=(255,))) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', 'CHW', in_min=(0,), in_max=(1,))) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', 'CHW', in_min=(0,), in_max=(255,))) + + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_float[:, :, :, 0:1], 'HWC', 'CHW', in_min=0, in_max=1)) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nhwc_uint8[:, :, :, 0:1], 'HWC', 'CHW', in_min=0, in_max=255)) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_float[:, 0:1, :, :], 'CHW', 'CHW', in_min=0, in_max=1)) + _check((8, 1, 64, 32), 'uint8', numpy_to_images(nchw_uint8[:, 0:1, :, :], 'CHW', 'CHW', in_min=0, in_max=255)) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/tests/util.py b/tests/util.py index bcff6941..1bc152c8 100644 --- a/tests/util.py +++ b/tests/util.py @@ -21,10 +21,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + import contextlib import os import sys from contextlib import contextmanager +from typing import Any +from typing import Dict # ========================================================================= # @@ -58,12 +61,35 @@ def temp_wd(new_wd): @contextlib.contextmanager def temp_sys_args(new_argv): + # TODO: should this copy values? old_argv = sys.argv sys.argv = new_argv yield sys.argv = old_argv +@contextmanager +def temp_environ(environment: Dict[str, Any]): + # TODO: should this copy values? -- could use unittest.mock.patch.dict(...) 
+ # save the old environment + existing_env = {} + for k in environment: + if k in os.environ: + existing_env[k] = os.environ[k] + # update the environment + os.environ.update(environment) + # run the context + try: + yield + finally: + # restore the original environment + for k in environment: + if k in existing_env: + os.environ[k] = existing_env[k] + else: + del os.environ[k] + + # ========================================================================= # # END # # ========================================================================= #
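Review note on the `temp_environ` TODO: the standard library's `unittest.mock.patch.dict` already snapshots a mapping and restores it on exit, including removing keys that were added inside the block. The sketch below shows what that alternative might look like. The name `temp_environ_patched` is hypothetical and not part of this diff, and the sketch assumes all values are strings, since `os.environ` rejects non-string values.

```python
import os
from contextlib import contextmanager
from typing import Dict
from unittest import mock


@contextmanager
def temp_environ_patched(environment: Dict[str, str]):
    # Hypothetical variant of temp_environ from tests/util.py.
    # patch.dict copies os.environ on entry and restores it on exit,
    # so overwritten keys get their old values back and new keys are removed,
    # even if the body of the `with` block raises.
    with mock.patch.dict(os.environ, environment):
        yield


if __name__ == '__main__':
    # EXAMPLE_VAR is an illustrative variable name only.
    with temp_environ_patched({'EXAMPLE_VAR': '1'}):
        print(os.environ['EXAMPLE_VAR'])
    print('EXAMPLE_VAR' in os.environ)
```

One difference from the hand-written version: `patch.dict` restores the whole mapping, so any other environment changes made inside the block are also rolled back, which may or may not be what a given test wants.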