diff --git a/README.md b/README.md index 76501eea..1c3234aa 100644 --- a/README.md +++ b/README.md @@ -227,13 +227,14 @@ from disent.frameworks.vae import BetaVae from disent.metrics import metric_dci, metric_mig from disent.model import AutoEncoder from disent.model.ae import DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.schedule import CyclicSchedule + # create the dataset & dataloaders -# - ToStandardisedTensor transforms images from numpy arrays to tensors and performs checks +# - ToImgTensorF32 transforms images from numpy arrays to tensors and performs checks data = XYObjectData() -dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToStandardisedTensor()) +dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_workers=os.cpu_count()) # create the BetaVAE model @@ -261,7 +262,9 @@ module.register_schedule( # train model # - for 2048 batches/steps -trainer = pl.Trainer(max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False) +trainer = pl.Trainer( + max_steps=2048, gpus=1 if torch.cuda.is_available() else None, logger=False, checkpoint_callback=False +) trainer.fit(module, dataloader) # compute disentanglement metrics @@ -304,13 +307,13 @@ files (config options) in the subfolders (config groups) in ```yaml defaults: # system - - framework: adavae + - framework: adavae_os - model: vae_conv64 - optimizer: adam - schedule: none # data - dataset: xyobject - - dataset_sampling: full_bb + - sampling: default__bb - augment: none # runtime - metrics: fast diff --git a/disent/dataset/_base.py b/disent/dataset/_base.py index be1d17f4..fdb3f985 100644 --- a/disent/dataset/_base.py +++ b/disent/dataset/_base.py @@ -25,6 +25,7 @@ from functools import wraps from typing import Optional from typing import Sequence +from typing import TypeVar from typing import Union import numpy as np @@ -35,6 +36,7 @@ from disent.dataset.data import GroundTruthData from disent.dataset.sampling import SingleSampler from disent.dataset.wrapper import WrappedDataset +from disent.util.deprecate import deprecated from disent.util.iters import LengthIter from disent.util.math.random import random_choice_prng @@ -53,7 +55,10 @@ class NotGroundTruthDataError(Exception): """ -def groundtruth_only(func): +T = TypeVar('T') + + +def groundtruth_only(func: T) -> T: @wraps(func) def wrapper(self: 'DisentDataset', *args, **kwargs): if not self.is_ground_truth: @@ -76,8 +81,12 @@ def wrapper(self: 'DisentDataset', *args, **kwargs): # ========================================================================= # +_DO_COPY = object() + + class DisentDataset(Dataset, LengthIter): + def __init__( self, dataset: Union[Dataset, GroundTruthData], @@ -97,6 +106,20 @@ def __init__( if not self._sampler.is_init: self._sampler.init(dataset) + def shallow_copy( + self, + transform=_DO_COPY, + augment=_DO_COPY, + return_indices=_DO_COPY, + ) -> 'DisentDataset': + return DisentDataset( + dataset=self._dataset, + sampler=self._sampler, + transform=self._transform if (transform is _DO_COPY) else transform, + augment=self._augment if (augment is _DO_COPY) else augment, + return_indices=self._return_indices if (return_indices is _DO_COPY) else return_indices, + ) + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Properties # # - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - # @@ -114,6 +137,7 @@ def is_ground_truth(self) -> bool: return isinstance(self._dataset, GroundTruthData) @property + @deprecated('ground_truth_data property replaced with `gt_data`') @groundtruth_only def ground_truth_data(self) -> GroundTruthData: return self._dataset @@ -284,13 +308,13 @@ def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = Fals @groundtruth_only def dataset_batch_from_factors(self, factors: np.ndarray, mode: str): """Get a batch of observations X from a batch of factors Y.""" - indices = self.ground_truth_data.pos_to_idx(factors) + indices = self.gt_data.pos_to_idx(factors) return self.dataset_batch_from_indices(indices, mode=mode) @groundtruth_only def dataset_sample_batch_with_factors(self, num_samples: int, mode: str): """Sample a batch of observations X and factors Y.""" - factors = self.ground_truth_data.sample_factors(num_samples) + factors = self.gt_data.sample_factors(num_samples) batch = self.dataset_batch_from_factors(factors, mode=mode) return batch, default_collate(factors) diff --git a/disent/dataset/data/__init__.py b/disent/dataset/data/__init__.py index 86bbba5a..1a870327 100644 --- a/disent/dataset/data/__init__.py +++ b/disent/dataset/data/__init__.py @@ -50,3 +50,4 @@ # groundtruth -- impl synthetic from disent.dataset.data._groundtruth__xyobject import XYObjectData +from disent.dataset.data._groundtruth__xyobject import XYObjectShadedData diff --git a/disent/dataset/data/_episodes__custom.py b/disent/dataset/data/_episodes__custom.py index 37b0cf07..07bbc9f1 100644 --- a/disent/dataset/data/_episodes__custom.py +++ b/disent/dataset/data/_episodes__custom.py @@ -79,7 +79,7 @@ def _load_episode_observations(self) -> List[np.ndarray]: # check variables option_ids_to_names = {} ground_truth_keys = None - observation_shape = None + img_shape = None # load data episodes = [] for i, raw_episode in enumerate(raw_episodes): @@ -102,11 +102,11 @@ def _load_episode_observations(self) -> List[np.ndarray]: for gt_state in ground_truth_states: assert ground_truth_keys == gt_state.keys() # CHECK: observation shapes - if observation_shape is None: - observation_shape = observed_states[0].shape + if img_shape is None: + img_shape = observed_states[0].shape else: for observation in observed_states: - assert observation.shape == observation_shape + assert observation.shape == img_shape # APPEND: all observations into one long episode rollout.extend(observed_states) # cleanup unused memory! This is not ideal, but works well. 
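Note: below is a minimal usage sketch of the renamed pieces above (the `ToImgTensorF32` transform, the `gt_data` property that replaces the deprecated `ground_truth_data`, and the new `DisentDataset.shallow_copy`). It is a sketch only, assuming the usual quickstart imports from `disent.dataset`, `disent.dataset.data` and `disent.dataset.sampling`.

```python
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32, ToImgTensorU8

# wrap ground-truth data in a DisentDataset, using the renamed float32 image transform
data = XYObjectData()
dataset = DisentDataset(dataset=data, sampler=SingleSampler(), transform=ToImgTensorF32())
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# `gt_data` replaces the deprecated `ground_truth_data` property
print(dataset.gt_data.factor_names)

# `shallow_copy` reuses the same underlying data and sampler, but swaps out the transform
dataset_u8 = dataset.shallow_copy(transform=ToImgTensorU8())
```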
diff --git a/disent/dataset/data/_groundtruth.py b/disent/dataset/data/_groundtruth.py index ae243623..0c269ba4 100644 --- a/disent/dataset/data/_groundtruth.py +++ b/disent/dataset/data/_groundtruth.py @@ -62,6 +62,10 @@ def __init__(self, transform=None): factor_names=self.factor_names, ) + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # Overridable Defaults # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + @property def name(self): name = self.__class__.__name__ @@ -70,7 +74,7 @@ def name(self): return name.lower() # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - # Overrides # + # State Space # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @property @@ -81,34 +85,22 @@ def factor_names(self) -> Tuple[str, ...]: def factor_sizes(self) -> Tuple[int, ...]: raise NotImplementedError() - @property - def observation_shape(self) -> Tuple[int, ...]: - # TODO: deprecate this! - # TODO: observation_shape should be called img_shape - # shape as would be for a non-batched observation - # eg. H x W x C - raise NotImplementedError() + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # Properties # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @property def x_shape(self) -> Tuple[int, ...]: - # TODO: deprecate this! - # TODO: x_shape should be called obs_shape # shape as would be for a single observation in a torch batch # eg. C x H x W - shape = self.observation_shape - return shape[-1], *shape[:-1] + H, W, C = self.img_shape + return (C, H, W) @property def img_shape(self) -> Tuple[int, ...]: # shape as would be for an original image # eg. H x W x C - return self.observation_shape - - @property - def obs_shape(self) -> Tuple[int, ...]: - # shape as would be for a single observation in a torch batch - # eg. 
C x H x W - return self.x_shape + raise NotImplementedError() @property def img_channels(self) -> int: @@ -116,6 +108,10 @@ def img_channels(self) -> int: assert channels in (1, 3), f'invalid number of channels for dataset: {self.__class__.__name__}, got: {repr(channels)}, required: 1 or 3' return channels + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # Overrides # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + def __getitem__(self, idx): obs = self._get_observation(idx) if self._transform is not None: @@ -149,23 +145,23 @@ def sample_random_obs_traversal(self, f_idx: int = None, base_factors=None, num: class ArrayGroundTruthData(GroundTruthData): - def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], array_chn_is_last: bool = True, observation_shape: Optional[Tuple[int, ...]] = None, transform=None): + def __init__(self, array, factor_names: Tuple[str, ...], factor_sizes: Tuple[int, ...], array_chn_is_last: bool = True, x_shape: Optional[Tuple[int, ...]] = None, transform=None): self.__factor_names = tuple(factor_names) self.__factor_sizes = tuple(factor_sizes) self._array = array # get shape - if observation_shape is not None: - C, H, W = observation_shape + if x_shape is not None: + C, H, W = x_shape elif array_chn_is_last: H, W, C = array.shape[1:] else: C, H, W = array.shape[1:] # set observation shape - self.__observation_shape = (H, W, C) + self.__img_shape = (H, W, C) # initialize super().__init__(transform=transform) # check shapes -- it is up to the user to handle which method they choose - assert (array.shape[1:] == self.img_shape) or (array.shape[1:] == self.obs_shape) + assert (array.shape[1:] == self.img_shape) or (array.shape[1:] == self.x_shape) @property def array(self): @@ -180,8 +176,8 @@ def factor_sizes(self) -> Tuple[int, ...]: return self.__factor_sizes @property - def observation_shape(self) -> Tuple[int, ...]: - return self.__observation_shape + def img_shape(self) -> Tuple[int, ...]: + return self.__img_shape def _get_observation(self, idx): # TODO: INVESTIGATE! I think this implements a lock, @@ -189,13 +185,14 @@ def _get_observation(self, idx): return self._array[idx] @classmethod - def new_like(cls, array, dataset: GroundTruthData, array_chn_is_last: bool = True): + def new_like(cls, array, gt_data: GroundTruthData, array_chn_is_last: bool = True): + # TODO: should this not copy the x_shape and transform? return cls( array=array, - factor_names=dataset.factor_names, - factor_sizes=dataset.factor_sizes, + factor_names=gt_data.factor_names, + factor_sizes=gt_data.factor_sizes, array_chn_is_last=array_chn_is_last, - observation_shape=None, # infer from array + x_shape=None, # infer from array transform=None, ) @@ -207,15 +204,12 @@ def new_like(cls, array, dataset: GroundTruthData, array_chn_is_last: bool = Tru # ========================================================================= # -class DiskGroundTruthData(GroundTruthData, metaclass=ABCMeta): +class _DiskDataMixin(object): - """ - Dataset that prepares a list DataObjects into some local directory. 
- - This directory can be - """ + # attr this class defines in _mixin_disk_init + _data_dir: str - def __init__(self, data_root: Optional[str] = None, prepare: bool = False, transform=None): - super().__init__(transform=transform) + def _mixin_disk_init(self, data_root: Optional[str] = None, prepare: bool = False): # get root data folder if data_root is None: data_root = self.default_data_root @@ -242,6 +236,23 @@ def default_data_root(self): def datafiles(self) -> Sequence[DataFile]: raise NotImplementedError + @property + def name(self) -> str: + raise NotImplementedError + + +class DiskGroundTruthData(_DiskDataMixin, GroundTruthData, metaclass=ABCMeta): + + """ + Dataset that prepares a list DataObjects into some local directory. + - This directory can be + """ + + def __init__(self, data_root: Optional[str] = None, prepare: bool = False, transform=None): + super().__init__(transform=transform) + # get root data folder + self._mixin_disk_init(data_root=data_root, prepare=prepare) + class NumpyFileGroundTruthData(DiskGroundTruthData, metaclass=ABCMeta): """ @@ -282,7 +293,7 @@ def data_key(self) -> Optional[str]: class _Hdf5DataMixin(object): - # set attributes if _mixin_hdf5_init is called + # attrs this class defines in _mixin_hdf5_init _in_memory: bool _attrs: dict _data: Union[Hdf5Dataset, np.ndarray] @@ -303,11 +314,21 @@ def _mixin_hdf5_init(self, h5_path: str, h5_dataset_name: str = 'data', in_memor # indexing dataset objects returns numpy array # instantiating np.array from the dataset requires double memory. self._data = data[:] + self._data.flags.writeable = False data.close() else: # Load the dataset from the disk self._data = data + def __len__(self): + return len(self._data) + + @property + def img_shape(self): + shape = self._data.shape[1:] + assert len(shape) == 3 + return shape + # override from GroundTruthData def _get_observation(self, idx): return self._data[idx] @@ -353,7 +374,7 @@ def __init__(self, h5_path: str, in_memory=False, transform=None): self._attr_factor_sizes = tuple(int(size) for size in self._attrs['factor_sizes']) # set size (B, H, W, C) = self._data.shape - self._observation_shape = (H, W, C) + self._img_shape = (H, W, C) # initialize! 
super().__init__(transform=transform) @@ -370,8 +391,8 @@ def factor_sizes(self) -> Tuple[int, ...]: return self._attr_factor_sizes @property - def observation_shape(self) -> Tuple[int, ...]: - return self._observation_shape + def img_shape(self) -> Tuple[int, ...]: + return self._img_shape # ========================================================================= # diff --git a/disent/dataset/data/_groundtruth__cars3d.py b/disent/dataset/data/_groundtruth__cars3d.py index 096e6098..522bc749 100644 --- a/disent/dataset/data/_groundtruth__cars3d.py +++ b/disent/dataset/data/_groundtruth__cars3d.py @@ -109,7 +109,7 @@ class Cars3dData(NumpyFileGroundTruthData): factor_names = ('elevation', 'azimuth', 'object_type') factor_sizes = (4, 24, 183) # TOTAL: 17568 - observation_shape = (128, 128, 3) + img_shape = (128, 128, 3) datafile = DataFileCars3d( uri='http://www.scottreed.info/files/nips2015-analogy-data.tar.gz', diff --git a/disent/dataset/data/_groundtruth__dsprites.py b/disent/dataset/data/_groundtruth__dsprites.py index 6de593a2..99ee76ae 100644 --- a/disent/dataset/data/_groundtruth__dsprites.py +++ b/disent/dataset/data/_groundtruth__dsprites.py @@ -53,7 +53,7 @@ class DSpritesData(Hdf5GroundTruthData): # TODO: reference implementation has colour variants factor_names = ('shape', 'scale', 'orientation', 'position_x', 'position_y') factor_sizes = (3, 6, 40, 32, 32) # TOTAL: 737280 - observation_shape = (64, 64, 1) + img_shape = (64, 64, 1) datafile = DataFileHashedDlH5( # download file/link diff --git a/disent/dataset/data/_groundtruth__dsprites_imagenet.py b/disent/dataset/data/_groundtruth__dsprites_imagenet.py new file mode 100644 index 00000000..008aee28 --- /dev/null +++ b/disent/dataset/data/_groundtruth__dsprites_imagenet.py @@ -0,0 +1,286 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import logging
+import os
+import shutil
+from tempfile import TemporaryDirectory
+from typing import Optional
+
+import numpy as np
+import psutil
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+from tqdm import tqdm
+
+from disent.dataset.data import GroundTruthData
+from disent.dataset.data._groundtruth import _DiskDataMixin
+from disent.dataset.data._groundtruth import _Hdf5DataMixin
+from disent.dataset.data._groundtruth__dsprites import DSpritesData
+from disent.dataset.transform import ToImgTensorF32
+from disent.dataset.util.datafile import DataFileHashedDlGen
+from disent.dataset.util.hdf5 import H5Builder
+from disent.dataset.util.stats import compute_data_mean_std
+from disent.util.inout.files import AtomicSaveFile
+from disent.util.iters import LengthIter
+from disent.util.math.random import random_choice_prng
+
+
+log = logging.getLogger(__name__)
+
+
+# ========================================================================= #
+# load imagenet-tiny data                                                   #
+# ========================================================================= #
+
+
+class NumpyFolder(ImageFolder):
+    def __getitem__(self, idx):
+        img, cls = super().__getitem__(idx)
+        return np.array(img)
+
+
+def load_imagenet_tiny_data(raw_data_dir):
+    data = NumpyFolder(os.path.join(raw_data_dir, 'train'))
+    data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=lambda x: x)
+    # load data - this is a bit memory inefficient doing it like this instead of with a loop into a pre-allocated array
+    imgs = np.concatenate(list(tqdm(data, 'loading')), axis=0)
+    assert imgs.shape == (100_000, 64, 64, 3)
+    return imgs
+
+
+def resave_imagenet_tiny_archive(orig_zipped_file, new_save_file, overwrite=False, h5_dataset_name: str = 'data'):
+    """
+    Convert an imagenet tiny archive to an hdf5 or numpy file, depending on the file extension.
+    Uncompresses the contents of the archive into a temporary directory in the same folder,
+    loads the images, and then converts them.
+    """
+    _, ext = os.path.splitext(new_save_file)
+    assert ext in {'.npz', '.h5'}, f'unsupported save extension: {repr(ext)}, must be one of: {[".npz", ".h5"]}'
+    # extract zipfile into temp dir
+    with TemporaryDirectory(prefix='unzip_imagenet_tiny_', dir=os.path.dirname(orig_zipped_file)) as temp_dir:
+        log.info(f"Extracting into temporary directory: {temp_dir}")
+        shutil.unpack_archive(filename=orig_zipped_file, extract_dir=temp_dir)
+        images = load_imagenet_tiny_data(raw_data_dir=os.path.join(temp_dir, 'tiny-imagenet-200'))
+    # save the data
+    with AtomicSaveFile(new_save_file, overwrite=overwrite) as temp_file:
+        # check the mode
+        with H5Builder(temp_file, 'atomic_w') as builder:
+            builder.add_dataset_from_array(
+                name=h5_dataset_name,
+                array=images,
+                chunk_shape='batch',
+                compression_lvl=4,
+                attrs=None,
+                show_progress=True,
+            )
+
+
+# ========================================================================= #
+# imagenet-tiny data object                                                 #
+# ========================================================================= #
+
+
+class ImageNetTinyDataFile(DataFileHashedDlGen):
+    """
+    download the tiny-imagenet-200 dataset and convert it to an hdf5 file.
+    """
+
+    dataset_name: str = 'data'
+
+    def _generate(self, inp_file: str, out_file: str):
+        resave_imagenet_tiny_archive(orig_zipped_file=inp_file, new_save_file=out_file, overwrite=True, h5_dataset_name=self.dataset_name)
+
+
+class ImageNetTinyData(_Hdf5DataMixin, _DiskDataMixin, Dataset, LengthIter):
+
+    name = 'imagenet_tiny'
+
+    datafile_imagenet_h5 = ImageNetTinyDataFile(
+        uri='http://cs231n.stanford.edu/tiny-imagenet-200.zip',
+        uri_hash={'fast': '4d97ff8efe3745a3bba9917d6d536559', 'full': '90528d7ca1a48142e341f4ef8d21d0de'},
+        file_hash={'fast': '9c23e8ec658b1ec9f3a86afafbdbae51', 'full': '4c32b0b53f257ac04a3afb37e3a4204e'},
+        uri_name='tiny-imagenet-200.zip',
+        file_name='tiny-imagenet-200.h5',
+        hash_mode='full'
+    )
+
+    datafiles = (datafile_imagenet_h5,)
+
+    def __init__(self, data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
+        super().__init__()
+        self._transform = transform
+        # initialize mixin
+        self._mixin_disk_init(
+            data_root=data_root,
+            prepare=prepare,
+        )
+        self._mixin_hdf5_init(
+            h5_path=os.path.join(self.data_dir, self.datafile_imagenet_h5.out_name),
+            h5_dataset_name=self.datafile_imagenet_h5.dataset_name,
+            in_memory=in_memory,
+        )
+
+    def __getitem__(self, idx: int):
+        obs = self._data[idx]
+        if self._transform is not None:
+            obs = self._transform(obs)
+        return obs
+
+
+# ========================================================================= #
+# dataset_dsprites_imagenet                                                 #
+# ========================================================================= #
+
+
+class DSpritesImagenetData(GroundTruthData):
+    """
+    DSprites that has imagenet images in the background.
+    """
+
+    name = 'dsprites_imagenet'
+
+    # original dsprites is only (64, 64, 1), imagenet adds the colour channels
+    img_shape = (64, 64, 3)
+    factor_names = DSpritesData.factor_names
+    factor_sizes = DSpritesData.factor_sizes
+
+    def __init__(self, visibility: int = 100, mode: str = 'bg', data_root: Optional[str] = None, prepare: bool = False, in_memory=False, transform=None):
+        super().__init__(transform=transform)
+        # check visibility and convert to ratio
+        assert isinstance(visibility, int), f'incorrect visibility percentage type, expected int, got: {type(visibility)}'
+        assert 0 <= visibility <= 100, f'incorrect visibility percentage: {repr(visibility)}, must be in range [0, 100].
' + self._visibility = visibility / 100 + # check mode and convert to foreground boolean + assert mode in {'bg', 'fg'}, f'incorrect mode: {repr(mode)}, must be one of: ["bg", "fg"]' + self._foreground = (mode == 'fg') + # handle the datasets + self._dsprites = DSpritesData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None) + self._imagenet = ImageNetTinyData(data_root=data_root, prepare=prepare, in_memory=in_memory, transform=None) + # deterministic randomization of the imagenet order + self._imagenet_order = random_choice_prng( + len(self._imagenet), + size=len(self), + seed=42, + ) + + def _get_observation(self, idx): + # we need to combine the two dataset images + # dsprites contains only {0, 255} for values + # we can directly use these values to mask the imagenet image + bg = self._imagenet[self._imagenet_order[idx]] + fg = self._dsprites[idx].repeat(3, axis=-1) + # compute background + # set foreground + r = self._visibility + if self._foreground: + # lerp content to white, and then insert into fg regions + # r*bg + (1-r)*255 + obs = (r*bg + ((1-r)*255)).astype('uint8') + obs[fg <= 127] = 0 + else: + # lerp content to black, and then insert into bg regions + # r*bg + (1-r)*000 + obs = (r*bg).astype('uint8') + obs[fg > 127] = 255 + # checks + return obs + + +# ========================================================================= # +# STATS # +# ========================================================================= # + + +""" +dsprites_fg_1.0 + vis_mean: [0.02067051643494642, 0.018688392816012946, 0.01632900510079384] + vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281] +dsprites_fg_0.8 + vis_mean: [0.024956427531012196, 0.02336780403840578, 0.021475119672280243] + vis_std: [0.11864125016313823, 0.11137998105649799, 0.10281424917834255] +dsprites_fg_0.6 + vis_mean: [0.029335176871153983, 0.028145355435322966, 0.026731731769287146] + vis_std: [0.13663242436043319, 0.13114320478634894, 0.1246542727733097] +dsprites_fg_0.4 + vis_mean: [0.03369999506331255, 0.03290657349801835, 0.03196482946320608] + vis_std: [0.155514074438101, 0.1518464537731621, 0.14750944591836743] +dsprites_fg_0.2 + vis_mean: [0.038064750024334834, 0.03766780505193579, 0.03719798677641122] + vis_std: [0.17498878664096565, 0.17315570657628318, 0.1709923319496426] +dsprites_bg_1.0 + vis_mean: [0.5020433619489952, 0.47206398913310593, 0.42380018909780404] + vis_std: [0.2505510666843685, 0.25007259803668697, 0.2562415603123114] +dsprites_bg_0.8 + vis_mean: [0.40867981393820857, 0.38468564002021527, 0.34611573047508204] + vis_std: [0.22048328737091344, 0.22102216869942384, 0.22692977053753477] +dsprites_bg_0.6 + vis_mean: [0.31676960943447674, 0.29877166834408025, 0.2698556821388113] + vis_std: [0.19745897110349003, 0.1986606891520453, 0.203808842880044] +dsprites_bg_0.4 + vis_mean: [0.2248598986983768, 0.21285772298967615, 0.19359577132944206] + vis_std: [0.1841631708032332, 0.18554895825833284, 0.1893568926398198] +dsprites_bg_0.2 + vis_mean: [0.13294969414492142, 0.12694375140936273, 0.11733572285575933] + vis_std: [0.18311250427586276, 0.1840916474752131, 0.18607373519458442] +""" + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + + + def compute_all_stats(): + from disent.util.visualize.plot import plt_subplots_imshow + + def compute_stats(visibility, mode): + # plot images + data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode) + grid = np.array([data[i*24733] for i in np.arange(16)]).reshape([4, 4, *data.img_shape]) + 
plt_subplots_imshow(grid, show=True, title=f'{DSpritesImagenetData.name} visibility={repr(visibility)} mode={repr(mode)}') + # compute stats + name = f'dsprites_{mode}_{visibility}' + data = DSpritesImagenetData(prepare=True, visibility=visibility, mode=mode, transform=ToImgTensorF32()) + mean, std = compute_data_mean_std(data, batch_size=256, num_workers=min(psutil.cpu_count(logical=False), 64), progress=True) + print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}') + # return stats + return name, mean, std + + # compute common stats + stats = [] + for mode in ['fg', 'bg']: + for vis in [100, 80, 60, 40, 20]: + stats.append(compute_stats(vis, mode)) + + # print once at end + for name, mean, std in stats: + print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}') + + compute_all_stats() + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/dataset/data/_groundtruth__mpi3d.py b/disent/dataset/data/_groundtruth__mpi3d.py index c27e1836..b759acc7 100644 --- a/disent/dataset/data/_groundtruth__mpi3d.py +++ b/disent/dataset/data/_groundtruth__mpi3d.py @@ -44,17 +44,18 @@ class Mpi3dData(NumpyFileGroundTruthData): reference implementation: https://github.com/google-research/disentanglement_lib/blob/master/disentanglement_lib/data/ground_truth/mpi3d.py """ - name = 'mpi3d' - MPI3D_DATASETS = { - 'toy': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash=None), - 'realistic': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash=None), - 'real': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', uri_hash=None), + 'toy': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_toy.npz', uri_hash={'fast': '146138e36ff495e77ceacdc8cf14c37e', 'full': '55889cb7c7dfc655d6e0277beee88868'}), + 'realistic': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_realistic.npz', uri_hash={'fast': '96c8ff1155dd61f79d3493edef9f19e9', 'full': '59a6225b88b635365f70c91b3e52f70f'}), + 'real': DataFileHashedDl(uri='https://storage.googleapis.com/disentanglement_dataset/Final_Dataset/mpi3d_real.npz', uri_hash={'fast': 'e2941bba6f4a2b130edc5f364637b39e', 'full': '0f33f609918fb5c97996692f91129802'}), } factor_names = ('object_color', 'object_shape', 'object_size', 'camera_height', 'background_color', 'first_dof', 'second_dof') factor_sizes = (4, 4, 2, 3, 3, 40, 40) # TOTAL: 460800 - observation_shape = (64, 64, 3) + img_shape = (64, 64, 3) + + # override + data_key = 'images' def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subset='realistic', in_memory=False, transform=None): # check subset is correct @@ -72,6 +73,10 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, subse def datafile(self) -> DataFileHashedDl: return self.MPI3D_DATASETS[self._subset] + @property + def name(self) -> str: + return f'mpi3d_{self._subset}' + # ========================================================================= # # END # diff --git a/disent/dataset/data/_groundtruth__norb.py b/disent/dataset/data/_groundtruth__norb.py index 8dd74012..c96d0505 100644 --- a/disent/dataset/data/_groundtruth__norb.py +++ b/disent/dataset/data/_groundtruth__norb.py @@ 
-141,7 +141,7 @@ class SmallNorbData(DiskGroundTruthData): factor_names = ('category', 'instance', 'elevation', 'rotation', 'lighting') factor_sizes = (5, 5, 9, 18, 6) # TOTAL: 24300 - observation_shape = (96, 96, 1) + img_shape = (96, 96, 1) TRAIN_DATA_FILES = { 'dat': DataFileHashedDl(uri='https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat.gz', uri_hash={'fast': '92560cccc7bcbd6512805e435448b62d', 'full': '66054832f9accfe74a0f4c36a75bc0a2'}), diff --git a/disent/dataset/data/_groundtruth__shapes3d.py b/disent/dataset/data/_groundtruth__shapes3d.py index 0bcba018..444d44a1 100644 --- a/disent/dataset/data/_groundtruth__shapes3d.py +++ b/disent/dataset/data/_groundtruth__shapes3d.py @@ -44,11 +44,12 @@ class Shapes3dData(Hdf5GroundTruthData): info: https://console.cloud.google.com/storage/browser/_details/3d-shapes/3dshapes.h5 """ + # TODO: name should be `shapes3d` so that it is a valid python identifier name = '3dshapes' factor_names = ('floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation') factor_sizes = (10, 10, 10, 8, 4, 15) # TOTAL: 480000 - observation_shape = (64, 64, 3) + img_shape = (64, 64, 3) datafile = DataFileHashedDlH5( # download file/link @@ -59,7 +60,7 @@ class Shapes3dData(Hdf5GroundTruthData): # h5 re-save settings hdf5_dataset_name='images', hdf5_chunk_size=(1, 64, 64, 3), - hdf5_obs_shape=observation_shape, + hdf5_obs_shape=img_shape, ) diff --git a/disent/dataset/data/_groundtruth__xyobject.py b/disent/dataset/data/_groundtruth__xyobject.py index 84da50f1..9c7db0fe 100644 --- a/disent/dataset/data/_groundtruth__xyobject.py +++ b/disent/dataset/data/_groundtruth__xyobject.py @@ -22,6 +22,8 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import warnings +from typing import Optional from typing import Tuple import numpy as np @@ -30,7 +32,25 @@ # ========================================================================= # -# xy grid data # +# helper # +# ========================================================================= # + + +_R, _G, _B, _Y, _C, _M, _W = np.array([ + [255, 000, 000], [000, 255, 000], [000, 000, 255], # R, G, B + [255, 255, 000], [000, 255, 255], [255, 000, 255], # Y, C, M + [255, 255, 255], # white +]) + + +def _shades(num: int, shades): + all_shades = np.array([shade * i // num for i in range(1, num+1) for shade in np.array(shades)]) + assert all_shades.dtype in ('int64', 'int32') + return all_shades + + +# ========================================================================= # +# xy object data # # ========================================================================= # @@ -40,64 +60,40 @@ class XYObjectData(GroundTruthData): Dataset that generates all possible permutations of a square placed on a square grid, with varying scale and colour - - Does not seem to learn with a VAE when square size is equal to 1 - (This property may be explained in the paper "Understanding disentanglement in Beta-VAEs") + *NB* for most of these color palettes, there should be + an extra ground truth factor that represents shade. + We purposely leave this out to hinder disentanglement! It is subjective! 
""" + name = 'xy_object' + COLOR_PALETTES_1 = { - 'white': [ - [255], - ], - 'greys_halves': [ - [128], - [255], - ], - 'greys_quarters': [ - [64], - [128], - [192], - [255], - ], - 'colors': [ - [64], - [128], - [192], - [255], - ], + 'greys_1': _shades(1, [[255]]), + 'greys_2': _shades(2, [[255]]), + 'greys_4': _shades(4, [[255]]), + # aliases for greys so that we can just set `rgb=False` and it still works + 'rainbow_1': _shades(1, [[255]]), + 'rainbow_2': _shades(2, [[255]]), + 'rainbow_4': _shades(4, [[255]]), } COLOR_PALETTES_3 = { - 'white': [ - [255, 255, 255], - ], - 'greys_halves': [ - [128, 128, 128], - [255, 255, 255], - ], - 'greys_quarters': [ - [64, 64, 64], - [128, 128, 128], - [192, 192, 192], - [255, 255, 255], - ], - 'rgb': [ - [255, 000, 000], - [000, 255, 000], - [000, 000, 255], - ], - 'colors': [ - [255, 000, 000], [000, 255, 000], [000, 000, 255], - [255, 255, 000], [000, 255, 255], [255, 000, 255], - [255, 255, 255], - ], - 'colors_halves': [ - [128, 000, 000], [000, 128, 000], [000, 000, 128], - [128, 128, 000], [000, 128, 128], [128, 000, 128], - [128, 128, 128], - [255, 000, 000], [000, 255, 000], [000, 000, 255], - [255, 255, 000], [000, 255, 255], [255, 000, 255], - [255, 255, 255], - ], + # grey + 'greys_1': _shades(1, [_W]), + 'greys_2': _shades(2, [_W]), + 'greys_4': _shades(4, [_W]), + # colors -- white here and the incorrect ordering may throw off learning ground truth factors + 'colors_1': _shades(1, [_R, _G, _B, _Y, _C, _M, _W]), + 'colors_2': _shades(2, [_R, _G, _B, _Y, _C, _M, _W]), + 'colors_4': _shades(4, [_R, _G, _B, _Y, _C, _M, _W]), + # rgb + 'rgb_1': _shades(1, [_R, _G, _B]), + 'rgb_2': _shades(2, [_R, _G, _B]), + 'rgb_4': _shades(4, [_R, _G, _B]), + # rainbows -- these colors are mostly ordered correctly to align with gt factors + 'rainbow_1': _shades(1, [_R, _Y, _G, _C, _B, _M]), + 'rainbow_2': _shades(2, [_R, _Y, _G, _C, _B, _M]), + 'rainbow_4': _shades(4, [_R, _Y, _G, _C, _B, _M]), } factor_names = ('x', 'y', 'scale', 'color') @@ -107,16 +103,43 @@ def factor_sizes(self) -> Tuple[int, ...]: return self._placements, self._placements, len(self._square_scales), len(self._colors) @property - def observation_shape(self) -> Tuple[int, ...]: + def img_shape(self) -> Tuple[int, ...]: return self._width, self._width, (3 if self._rgb else 1) - def __init__(self, grid_size=64, grid_spacing=1, min_square_size=3, max_square_size=9, square_size_spacing=2, rgb=True, palette='colors', transform=None): + def __init__( + self, + grid_size: int = 64, + grid_spacing: int = 2, + min_square_size: int = 7, + max_square_size: int = 15, + square_size_spacing: int = 2, + rgb: bool = True, + palette: str = 'rainbow_4', + transform=None, + warn_: bool = True + ): + if warn_: + warnings.warn( + '`XYObjectData` defaults were changed in disent v0.3.0, if you want `approx` <= v0.2.x behavior then use the following parameters. Pallets also changed slightly too.' 
+ '\n\tgrid_size=64' + '\n\tgrid_spacing=1' + '\n\tmin_square_size=3' + '\n\tmax_square_size=9' + '\n\tsquare_size_spacing=2' + '\n\trgb=True' + '\n\tpalette="colors_1"' + ) # generation self._rgb = rgb - if rgb: - self._colors = np.array(XYObjectData.COLOR_PALETTES_3[palette]) - else: - self._colors = np.array(XYObjectData.COLOR_PALETTES_1[palette]) + # check the pallete name + assert len(str.split(palette, '_')) == 2, f'palette name must follow format: `_`, got: {repr(palette)}' + # get the color palette + color_palettes = (XYObjectData.COLOR_PALETTES_3 if rgb else XYObjectData.COLOR_PALETTES_1) + if palette not in color_palettes: + raise KeyError(f'color palette: {repr(palette)} does not exist for rgb={repr(rgb)}, select one of: {sorted(color_palettes.keys())}') + self._colors = color_palettes[palette] + assert self._colors.ndim == 2 + assert self._colors.shape[-1] == (3 if rgb else 1) # image sizes self._width = grid_size # square scales @@ -134,11 +157,101 @@ def _get_observation(self, idx): r = (self._max_square_size - s) // 2 x, y = self._spacing*x + r, self._spacing*y + r # GENERATE - obs = np.zeros(self.observation_shape, dtype=np.uint8) + obs = np.zeros(self.img_shape, dtype=np.uint8) obs[y:y+s, x:x+s] = self._colors[c] return obs +class XYOldObjectData(XYObjectData): + + name = 'xy_object_shaded' + + def __init__(self, grid_size=64, grid_spacing=1, min_square_size=3, max_square_size=9, square_size_spacing=2, rgb=True, palette='colors', transform=None): + super().__init__( + grid_size=grid_size, + grid_spacing=grid_spacing, + min_square_size=min_square_size, + max_square_size=max_square_size, + square_size_spacing=square_size_spacing, + rgb=rgb, + palette=palette, + transform=transform, + ) + + +# ========================================================================= # +# END # +# ========================================================================= # + + +class XYObjectShadedData(XYObjectData): + """ + Dataset that generates all possible permutations of a square placed on a square grid, + with varying scale and colour + + - This is like `XYObjectData` but has an extra factor that represents the shade. 
+ """ + + factor_names = ('x', 'y', 'scale', 'intensity', 'color') + + @property + def factor_sizes(self) -> Tuple[int, ...]: + return self._placements, self._placements, len(self._square_scales), self._brightness_levels, len(self._colors) + + @property + def img_shape(self) -> Tuple[int, ...]: + return self._width, self._width, (3 if self._rgb else 1) + + def __init__( + self, + grid_size: int = 64, + grid_spacing: int = 2, + min_square_size: int = 7, + max_square_size: int = 15, + square_size_spacing: int = 2, + rgb: bool = True, + palette: str = 'rainbow_4', + brightness_levels: Optional[int] = None, + transform=None, + ): + parts = palette.split('_') + if len(parts) > 1: + # extract num levels from the string + palette, b_levels = parts + b_levels = int(b_levels) + # handle conflict between brightness_levels and palette + if brightness_levels is None: + brightness_levels = b_levels + else: + warnings.warn(f'palette ends with brightness_levels integer: {repr(b_levels)} (ignoring) but actual brightness_levels parameter was already specified: {repr(brightness_levels)} (using)') + # check the brightness_levels + assert isinstance(brightness_levels, int), f'brightness_levels must be an integer, got: {type(brightness_levels)}' + assert 1 <= brightness_levels, f'brightness_levels must be >= 1, got: {repr(brightness_levels)}' + self._brightness_levels = brightness_levels + # initialize parent + super().__init__( + grid_size=grid_size, + grid_spacing=grid_spacing, + min_square_size=min_square_size, + max_square_size=max_square_size, + square_size_spacing=square_size_spacing, + rgb=rgb, + palette=f'{palette}_1', + transform=transform, + warn_=False, + ) + + def _get_observation(self, idx): + x, y, s, b, c = self.idx_to_pos(idx) + s = self._square_scales[s] + r = (self._max_square_size - s) // 2 + x, y = self._spacing*x + r, self._spacing*y + r + # GENERATE + obs = np.zeros(self.img_shape, dtype=np.uint8) + obs[y:y+s, x:x+s] = self._colors[c] * (b + 1) // self._brightness_levels + return obs + + # ========================================================================= # # END # # ========================================================================= # diff --git a/disent/nn/transform/__init__.py b/disent/dataset/transform/__init__.py similarity index 67% rename from disent/nn/transform/__init__.py rename to disent/dataset/transform/__init__.py index a67aea70..ca4aeded 100644 --- a/disent/nn/transform/__init__.py +++ b/disent/dataset/transform/__init__.py @@ -23,14 +23,17 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # transforms -from ._transforms import CheckTensor -from ._transforms import Noop -from ._transforms import ToStandardisedTensor +from disent.dataset.transform._transforms import CheckTensor +from disent.dataset.transform._transforms import Noop +from disent.dataset.transform._transforms import ToImgTensorF32 +from disent.dataset.transform._transforms import ToImgTensorU8 +from disent.dataset.transform._transforms import ToStandardisedTensor # deprecated +from disent.dataset.transform._transforms import ToUint8Tensor # deprecated # augments -from ._augment import FftGaussianBlur -from ._augment import FftBoxBlur -from ._augment import FftKernel +from disent.dataset.transform._augment import FftGaussianBlur +from disent.dataset.transform._augment import FftBoxBlur +from disent.dataset.transform._augment import FftKernel # disent dataset augment -from ._augment_groundtruth import DisentDatasetTransform +from 
disent.dataset.transform._augment_disent import DisentDatasetTransform diff --git a/disent/nn/transform/_augment.py b/disent/dataset/transform/_augment.py similarity index 100% rename from disent/nn/transform/_augment.py rename to disent/dataset/transform/_augment.py diff --git a/disent/nn/transform/_augment_groundtruth.py b/disent/dataset/transform/_augment_disent.py similarity index 100% rename from disent/nn/transform/_augment_groundtruth.py rename to disent/dataset/transform/_augment_disent.py diff --git a/disent/dataset/transform/_transforms.py b/disent/dataset/transform/_transforms.py new file mode 100644 index 00000000..8160e31d --- /dev/null +++ b/disent/dataset/transform/_transforms.py @@ -0,0 +1,155 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from typing import Optional +from typing import Sequence + +import torch +import disent.dataset.transform.functional as F_d + + +# ========================================================================= # +# Transforms # +# ========================================================================= # +from disent.util.deprecate import deprecated + + +class Noop(object): + """ + Transform that does absolutely nothing! + + See: disent.transform.functional.noop + """ + + def __call__(self, obs): + return obs + + def __repr__(self): + return f'{self.__class__.__name__}()' + + +class CheckTensor(object): + """ + Check that the data is a tensor, the right dtype, and in the required range. + + See: disent.transform.functional.check_tensor + """ + + def __init__( + self, + low: Optional[float] = 0., + high: Optional[float] = 1., + dtype: Optional[torch.dtype] = torch.float32, + ): + self._low = low + self._high = high + self._dtype = dtype + + def __call__(self, obs): + return F_d.check_tensor(obs, low=self._low, high=self._high, dtype=self._dtype) + + def __repr__(self): + kwargs = dict(low=self._low, high=self._high, dtype=self._dtype) + kwargs = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items() if (v is not None)) + return f'{self.__class__.__name__}({kwargs})' + + +class ToImgTensorF32(object): + """ + Basic transform that should be applied to most datasets, making sure + the image tensor is float32 and a specified size. + + Steps: + 1. resize image if size is specified + 2. 
if we have integer inputs, divide by 255 + 3. add missing channel to greyscale image + 4. move channels to first dim (H, W, C) -> (C, H, W) + 5. normalize using mean and std, values might thus be outside of the range [0, 1] + + See: disent.transform.functional.to_img_tensor_f32 + """ + + def __init__( + self, + size: Optional[F_d.SizeType] = None, + mean: Optional[Sequence[float]] = None, + std: Optional[Sequence[float]] = None, + ): + self._size = size + self._mean = tuple(mean) if (mean is not None) else None + self._std = tuple(std) if (std is not None) else None + + def __call__(self, obs) -> torch.Tensor: + return F_d.to_img_tensor_f32(obs, size=self._size, mean=self._mean, std=self._std) + + def __repr__(self): + kwargs = dict(size=self._size, mean=self._mean, std=self._std) + kwargs = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items() if (v is not None)) + return f'{self.__class__.__name__}({kwargs})' + + +class ToImgTensorU8(object): + """ + Basic transform that makes sure the image tensor is uint8 and a specified size. + + Steps: + 1. resize image if size is specified + 2. add missing channel to greyscale image + 3. move channels to first dim (H, W, C) -> (C, H, W) + + See: disent.transform.functional.to_img_tensor_u8 + """ + + def __init__( + self, + size: Optional[F_d.SizeType] = None, + ): + self._size = size + + def __call__(self, obs) -> torch.Tensor: + return F_d.to_img_tensor_u8(obs, size=self._size) + + def __repr__(self): + kwargs = f'size={repr(self._size)}' if (self._size is not None) else '' + return f'{self.__class__.__name__}({kwargs})' + + +# ========================================================================= # +# Deprecated # +# ========================================================================= # + + +@deprecated('ToStandardisedTensor renamed to ToImgTensorF32') +class ToStandardisedTensor(ToImgTensorF32): + pass + + +@deprecated('ToUint8Tensor renamed to ToImgTensorU8') +class ToUint8Tensor(ToImgTensorU8): + pass + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/dataset/transform/functional.py b/disent/dataset/transform/functional.py new file mode 100644 index 00000000..e3ebf447 --- /dev/null +++ b/disent/dataset/transform/functional.py @@ -0,0 +1,279 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypeVar +from typing import Union + +import numpy as np +from PIL.Image import Image +import torch +import torchvision.transforms.functional as F_tv + + +# ========================================================================= # +# Types # +# ========================================================================= # + + +_T = TypeVar('_T') + +Obs = Union[np.ndarray, Image] + +SizeType = Union[int, Tuple[int, int]] + + +# ========================================================================= # +# Functional Transforms # +# ========================================================================= # + + +def noop(obs: _T) -> _T: + """ + Transform that does absolutely nothing! + """ + return obs + + +def check_tensor( + obs: Any, + low: Optional[float] = 0., + high: Optional[float] = 1., + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Check that the input is a tensor, its datatype matches, and + that it is in the required range. + """ + # check is a tensor + assert torch.is_tensor(obs), 'observation is not a tensor' + # check type + if dtype is not None: + assert obs.dtype == dtype, f'tensor type {obs.dtype} is not required type {dtype}' + # check range + if low is not None: + assert low <= obs.min(), f'minimum value of tensor {obs.min()} is less than allowed minimum value: {low}' + if high is not None: + assert obs.max() <= high, f'maximum value of tensor {obs.max()} is greater than allowed maximum value: {high}' + # DONE! + return obs + + +# ========================================================================= # +# Normalized Image Tensors # +# ========================================================================= # + + +def to_img_tensor_u8( + obs: Obs, + size: Optional[SizeType] = None, +) -> torch.Tensor: + """ + Basic transform that makes sure the image tensor is uint8 and a specified size. + + Steps: + 1. resize image if size is specified + 2. add missing channel to greyscale image + 3. move channels to first dim (H, W, C) -> (C, H, W) + """ + # resize image + if size is not None: + if not isinstance(obs, Image): + obs = F_tv.to_pil_image(obs) + obs = F_tv.resize(obs, size=size) + # to numpy + if isinstance(obs, Image): + obs = np.array(obs) + # add missing axis + if obs.ndim == 2: + obs = obs[:, :, None] + # to tensor & move axis + obs = torch.from_numpy(obs) + obs = torch.moveaxis(obs, -1, -3) + # checks + assert obs.ndim == 3 + assert obs.dtype == torch.uint8 + # done! + return obs + + +def to_img_tensor_f32( + obs: Obs, + size: Optional[SizeType] = None, + mean: Optional[Sequence[float]] = None, + std: Optional[Sequence[float]] = None, +) -> torch.Tensor: + """ + Basic transform that should be applied to most datasets, making sure + the image tensor is float32 and a specified size. + + Steps: + 1. resize image if size is specified + 2. if we have integer inputs, divide by 255 + 3. add missing channel to greyscale image + 4. move channels to first dim (H, W, C) -> (C, H, W) + 5. 
normalize using mean and std, values might thus be outside of the range [0, 1] + """ + # resize image + if size is not None: + if not isinstance(obs, Image): + obs = F_tv.to_pil_image(obs) + obs = F_tv.resize(obs, size=size) + # transform to tensor, add missing dims & move channel dim to front + obs = F_tv.to_tensor(obs) + # checks + assert obs.ndim == 3, f'obs has does not have 3 dimensions, got: {obs.ndim} for shape: {obs.shape}' + assert obs.dtype == torch.float32, f'obs is not dtype torch.float32, got: {obs.dtype}' + # apply mean and std, we obs is of the shape (C, H, W) + if (mean is not None) or (std is not None): + obs = F_tv.normalize(obs, mean=0. if (mean is None) else mean, std=1. if (std is None) else std, inplace=True) + assert obs.dtype == torch.float32, f'after normalization, tensor should remain as dtype torch.float32, got: {obs.dtype}' + # done! + return obs + + +# ========================================================================= # +# Custom Normalized Image - Faster Than Above # +# ========================================================================= # + + +# def to_img_tensor_f32( +# x: Obs, +# size: Optional[SizeType] = None, +# channel_to_front: bool = None, +# ): +# """ +# Basic transform that should be applied to +# any dataset before augmentation. +# +# 1. resize if size is specified +# 2. if needed convert integers to float32 by dividing by 255 +# 3. normalize using mean and std, values might thus be outside of the range [0, 1] +# +# Convert PIL or uint8 inputs to float32 +# - input images should always have 2 (H, W) or 3 channels (H, W, C) +# - output image always has size (C, H, W) with channels moved to the first dim +# """ +# return _to_img_tensor(x, size=size, channel_to_front=channel_to_front, to_float32=True) +# +# +# def to_img_tensor_u8( +# x: Obs, +# size: Optional[SizeType] = None, +# channel_to_front: bool = None, +# ): +# """ +# Convert PIL or uint8 inputs to float32 +# - input images should always have 2 (H, W) or 3 channels (H, W, C) +# - output image always has size (C, H, W) with channels moved to the first dim +# """ +# return _to_img_tensor(x, size=size, channel_to_front=channel_to_front, to_float32=False) +# +# +# def _to_img_tensor( +# x: Obs, +# size: Optional[SizeType] = None, +# channel_to_front: bool = None, +# to_float32: bool = True, +# ) -> torch.Tensor: +# assert isinstance(x, (np.ndarray, Image)), f'input is not an numpy.ndarray or PIL.Image.Image, got: {type(x)}' +# # optionally resize the image, returns a numpy array or a PIL.Image.Image +# x = _resize_if_needed(x, size=size) +# # convert image to numpy +# if isinstance(x, Image): +# x = np.array(x) +# # make sure 2D becomes 3D +# if x.ndim == 2: +# x = x[:, :, None] +# assert x.ndim == 3, f'obs has invalid number of dimensions, required 2 or 3, got: {x.ndim} for shape: {x.shape}' +# # convert to float32 if int or uint +# if to_float32: +# if x.dtype.kind in ('i', 'u'): +# x = x.astype('float32') / 255 # faster than with torch +# # convert to torch.Tensor and move channels (H, W, C) -> (C, H, W) +# x = torch.from_numpy(x) +# if channel_to_front or (channel_to_front is None): +# x = torch.moveaxis(x, -1, 0) # faster than the numpy version +# # final check +# if to_float32: +# assert x.dtype == torch.float32, f'obs dtype invalid, required: {torch.float32}, got: {x.dtype}' +# else: +# assert x.dtype == torch.uint8, f'obs dtype invalid, required: {torch.uint8}, got: {x.dtype}' +# # done +# return x +# +# +# # 
========================================================================= # +# # Resize Image Helper # +# # ========================================================================= # +# +# +# _PIL_INTERPOLATE_MODES = { +# 'nearest': 0, +# 'lanczos': 1, +# 'bilinear': 2, +# 'bicubic': 3, +# 'box': 4, +# 'hamming': 5, +# } +# +# +# def _resize_if_needed(img: Union[np.ndarray, Image], size: Optional[Union[Tuple[int, int], int]] = None) -> Union[np.ndarray, Image]: +# # skip resizing +# if size is None: +# return img +# # normalize size +# if isinstance(size, int): +# size = (size, size) +# # get current image size +# if isinstance(img, Image): +# in_size = (img.height, img.width) +# else: +# assert img.ndim in (2, 3), f'image must have 2 or 3 dims, got shape: {img.shape}' +# in_size = img.shape[:2] +# # skip if the same size +# if in_size == size: +# return img +# # normalize the image +# if not isinstance(img, Image): +# assert img.dtype == 'uint8' +# # normalize image +# if img.ndim == 3: +# c = img.shape[-1] +# assert c in (1, 3), f'image channel dim must be of size 1 or 3, got shape: {img.shape}' +# img, mode = (img, 'RGB') if (c == 3) else (img[:, :, 0], 'L') +# else: +# mode = 'L' +# # convert +# img = PIL.Image.fromarray(img, mode=mode) +# # resize +# return img.resize(size, resample=_PIL_INTERPOLATE_MODES['bilinear']) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/dataset/util/hdf5.py b/disent/dataset/util/hdf5.py index a5ba0973..eec1058b 100644 --- a/disent/dataset/util/hdf5.py +++ b/disent/dataset/util/hdf5.py @@ -29,6 +29,7 @@ import contextlib import logging import os +from pathlib import Path from typing import Any from typing import Callable from typing import Dict @@ -44,6 +45,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from disent.util.deprecate import deprecated from disent.util.strings import colors as c from disent.util.inout.files import AtomicSaveFile from disent.util.iters import iter_chunks @@ -129,10 +131,22 @@ def h5_assert_deterministic(h5_file: h5py.File) -> h5py.File: @contextlib.contextmanager def h5_open(path: str, mode: str = 'r') -> h5py.File: - assert str.endswith(path, '.h5'), f'hdf5 file path does not end with extension: `.h5`' + """ + MODES: + atomic_w Create temp file, then move and overwrite existing when done + atomic_x Create temp file, then try move or fail if existing when done + r Readonly, file must exist (default) + r+ Read/write, file must exist + w Create file, truncate if exists + w- or x Create file, fail if exists + a Read/write if exists, create otherwise + """ + assert str.endswith(path, '.h5') or str.endswith(path, '.hdf5'), f'hdf5 file path does not end with extension: `.h5` or `.hdf5`, got: {path}' # get atomic context manager if mode == 'atomic_w': save_context, mode = AtomicSaveFile(path, open_mode=None, overwrite=True), 'w' + elif mode == 'atomic_x': + save_context, mode = AtomicSaveFile(path, open_mode=None, overwrite=False), 'x' else: save_context = contextlib.nullcontext(path) # handle saving to file @@ -143,13 +157,33 @@ def h5_open(path: str, mode: str = 'r') -> h5py.File: class H5Builder(object): - def __init__(self, h5_file: h5py.File): + def __init__(self, path: Union[str, Path], mode: str = 'x'): super().__init__() # make sure that the file is deterministic # - we might be missing some of the properties that control this # - should we add a recursive option? 
- h5_assert_deterministic(h5_file) - self._h5_file = h5_file + if not isinstance(path, (str, Path)): + raise TypeError(f'the given h5py path must be of type: `str`, `pathlib.Path`, got: {type(path)}') + self._h5_path = path + self._h5_mode = mode + self._context_manager = None + self._open_file = None + + def __enter__(self): + self._context_manager = h5_open(self._h5_path, self._h5_mode) + self._open_file = h5_assert_deterministic(self._context_manager.__enter__()) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._context_manager.__exit__(exc_type, exc_val, exc_tb) + self._open_file = None + self._context_manager = None + + @property + def _h5_file(self) -> h5py.File: + if self._open_file is None: + raise 'The H5Builder has not been opened in a new context, use `with H5Builder(...) as builder: ...`' + return self._open_file def add_dataset( self, @@ -217,7 +251,7 @@ def fill_dataset( # loop variables n = len(dataset) # save data - with tqdm(total=n, disable=not show_progress) as progress: + with tqdm(total=n, disable=not show_progress, desc=f'saving {name}') as progress: for i in range(0, n, batch_size): j = min(i + batch_size, n) assert j > i, f'this is a bug! {repr(j)} > {repr(i)}, len(dataset)={repr(n)}, batch_size={repr(batch_size)}' @@ -265,6 +299,10 @@ def get_batch_fn(i, j): batch = mutator(batch) return np.array(batch) + # get the batch size + if batch_size == 'auto' and isinstance(array, h5py.Dataset): + batch_size = array.chunks[0] + # copy into the dataset self.fill_dataset( name=name, @@ -302,6 +340,36 @@ def get_batch_fn(i, j): ) return self + def add_dataset_from_array( + self, + name: str, + array: np.ndarray, + chunk_shape: ChunksType = 'batch', + compression_lvl: Optional[int] = 4, + attrs: Optional[Dict[str, Any]] = None, + batch_size: Union[int, Literal['auto']] = 'auto', + show_progress: bool = False, + # optional, discovered automatically from array otherwise + mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, + dtype: Optional[np.dtype] = None, + shape: Optional[Tuple[int, ...]] = None, + ): + self.add_dataset( + name=name, + shape=array.shape if (shape is None) else shape, + dtype=array.dtype if (dtype is None) else dtype, + chunk_shape=chunk_shape, + compression_lvl=compression_lvl, + attrs=attrs, + ) + self.fill_dataset_from_array( + name=name, + array=array, + batch_size=batch_size, + show_progress=show_progress, + mutator=mutator, + ) + def add_dataset_from_gt_data( self, data: Union['DisentDataset', 'GroundTruthData'], @@ -309,7 +377,7 @@ def add_dataset_from_gt_data( img_shape: Tuple[Optional[int], ...] 
= (None, None, None), # None items are automatically found batch_size: int = 32, compression_lvl: Optional[int] = 9, - num_workers=min(os.cpu_count(), 16), + num_workers: int = min(os.cpu_count(), 16), show_progress: bool = True, dtype: str = 'uint8', attrs: Optional[dict] = None @@ -353,6 +421,65 @@ def add_dataset_from_gt_data( mutator=mutator, ) +# def resave_dataset(self, +# name: str, +# inp: Union[str, Path, h5py.File, h5py.Dataset, np.ndarray], +# # h5 re-save settings +# chunk_shape: ChunksType = 'batch', +# compression_lvl: Optional[int] = 4, +# attrs: Optional[Dict[str, Any]] = None, +# batch_size: Union[int, Literal['auto']] = 'auto', +# show_progress: bool = False, +# # optional, discovered automatically from array otherwise +# mutator: Optional[Callable[[np.ndarray], np.ndarray]] = None, +# dtype: Optional[np.dtype] = None, +# obs_shape: Optional[Tuple[int, ...]] = None, +# ): +# # TODO: should this be more general and be able to handle add_dataset_from_gt_data too? +# # TODO: this is very similar to save dataset below! +# with _get_array_context(inp, name) as arr: +# self.add_dataset_from_array( +# name=name, +# array=arr, +# chunk_shape=chunk_shape, +# compression_lvl=compression_lvl, +# attrs=attrs, +# batch_size=batch_size, +# show_progress=show_progress, +# mutator=mutator, +# dtype=dtype, +# shape=(len(arr), *obs_shape) if obs_shape else None, +# ) +# +# +# @contextlib.contextmanager +# def _get_array_context( +# inp: Union[str, Path, h5py.File, h5py.Dataset, np.ndarray], +# dataset_name: str = None, +# ) -> Union[h5py.Dataset, np.ndarray]: +# # check the inputs +# if not isinstance(inp, (str, Path, h5py.File, h5py.Dataset, np.ndarray)): +# raise TypeError(f'unsupported input type: {type(inp)}') +# # handle loading files +# if isinstance(inp, str): +# _, ext = os.path.splitext(inp) +# if ext in ('.h5', '.hdf5'): +# inp_context = h5py.File(inp, 'r') +# else: +# raise ValueError(f'unsupported extension: {repr(ext)} for path: {repr(inp)}') +# else: +# import contextlib +# inp_context = contextlib.nullcontext(inp) +# # re-save datasets +# with inp_context as inp_data: +# # get input dataset from h5 file +# if isinstance(inp_data, h5py.File): +# if dataset_name is None: +# raise ValueError('dataset_name must be specified if the input is an h5py.File so we can retrieve a h5py.Dataset') +# inp_data = inp_data[dataset_name] +# # return the data +# yield inp_data + # ========================================================================= # # hdf5 - resave # diff --git a/disent/dataset/util/state_space.py b/disent/dataset/util/state_space.py index 9186846e..5c72168f 100644 --- a/disent/dataset/util/state_space.py +++ b/disent/dataset/util/state_space.py @@ -21,16 +21,26 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + from functools import lru_cache from typing import Optional from typing import Sequence from typing import Tuple +from typing import Union import numpy as np from disent.util.iters import LengthIter from disent.util.visualize.vis_util import get_idx_traversal +# ========================================================================= # +# Types # +# ========================================================================= # + + +NonNormalisedFactors = Union[Sequence[Union[int, str]], Union[int, str]] + + # ========================================================================= # # Basic State Space # # ========================================================================= # @@ -86,6 +96,38 @@ def factor_names(self) -> Tuple[str, ...]: """A list of names of factors handled by this state space""" return self.__factor_names + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # Factor Helpers # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + def normalise_factor_idx(self, factor: Union[int, str]) -> int: + # convert a factor name to the factor id + if isinstance(factor, str): + try: + f_idx = self.factor_names.index(factor) + except: + raise KeyError(f'invalid factor name: {repr(factor)} must be one of: {self.factor_names}') + else: + f_idx = int(factor) + # check that the values are correct + assert isinstance(f_idx, int) + assert 0 <= f_idx < self.num_factors + # return the resulting values + return f_idx + + def normalise_factor_idxs(self, factors: 'NonNormalisedFactors') -> np.ndarray: + # return the default list of factor indices + if factors is None: + return np.arange(self.num_factors) + # normalize a single factor into a list + if isinstance(factors, (int, str)): + factors = [factors] + # convert all the factors to their indices + factors = np.array([self.normalise_factor_idx(factor) for factor in factors]) + # done! make sure there are not duplicates! + assert len(set(factors)) == len(factors), 'duplicate factors were found!' + return factors + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Coordinate Transform - any dim array, only last axis counts! # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py new file mode 100644 index 00000000..c08f66a3 --- /dev/null +++ b/disent/dataset/util/stats.py @@ -0,0 +1,166 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import os +from typing import Tuple + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from disent.util.function import wrapped_partial + + +# ========================================================================= # +# COMPUTE DATASET STATS # +# ========================================================================= # + + +@torch.no_grad() +def compute_data_mean_std( + data, + batch_size: int = 256, + num_workers: int = min(os.cpu_count(), 16), + progress: bool = False, + chn_is_last: bool = False +) -> Tuple[np.ndarray, np.ndarray]: + """ + Input data when collected using a DataLoader should return + `torch.Tensor`s, output mean and std are an `np.ndarray`s + """ + loader = DataLoader( + data, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + drop_last=False, + ) + if progress: + from tqdm import tqdm + loader = tqdm(loader, desc=f'{data.__class__.__name__} stats', total=(len(data) + batch_size - 1) // batch_size) + # reduction dims + dims = (1, 2) if chn_is_last else (2, 3) + # collect obs means & stds + img_means, img_stds = [], [] + for batch in loader: + assert isinstance(batch, torch.Tensor), f'batch must be an instance of torch.Tensor, got: {type(batch)}' + assert batch.ndim == 4, f'batch shape must be: (B, C, H, W), got: {tuple(batch.shape)}' + batch = batch.to(torch.float64) + img_means.append(torch.mean(batch, dim=dims)) + img_stds.append(torch.std(batch, dim=dims)) + # aggregate obs means & stds + mean = torch.mean(torch.cat(img_means, dim=0), dim=0) + std = torch.mean(torch.cat(img_stds, dim=0), dim=0) + # checks! + assert mean.ndim == 1 + assert std.ndim == 1 + # done! + return mean.numpy(), std.numpy() + + +# ========================================================================= # +# HELPER # +# ========================================================================= # + + +if __name__ == '__main__': + + def main(progress=False): + from disent.dataset import data + from disent.dataset.transform import ToImgTensorF32 + + for data_cls in [ + # groundtruth -- impl + data.Cars3dData, + data.DSpritesData, + data.SmallNorbData, + data.Shapes3dData, + wrapped_partial(data.Mpi3dData, subset='toy', in_memory=True), + wrapped_partial(data.Mpi3dData, subset='realistic', in_memory=True), + wrapped_partial(data.Mpi3dData, subset='real', in_memory=True), + # groundtruth -- impl synthetic + data.XYObjectData, + data.XYObjectShadedData, + ]: + from disent.dataset.transform import ToImgTensorF32 + # Most common standardized way of computing the mean and std over observations + # resized to 64px in size of dtype float32 in the range [0, 1]. + data = data_cls(transform=ToImgTensorF32(size=64)) + mean, std = compute_data_mean_std(data, progress=progress) + # results! + print(f'{data.__class__.__name__} - {data.name}:\n mean: {mean.tolist()}\n std: {std.tolist()}') + + # RUN! 
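# The helper above can be pointed at any of these data classes to recompute the
# statistics, and the resulting values can then be fed back into the transform.
# A sketch mirroring the __main__ block (XYObjectData is synthetic, so nothing
# needs to be downloaded; passing mean/std to ToImgTensorF32 is assumed to
# match the functional transform earlier in this diff):

from disent.dataset.data import XYObjectData
from disent.dataset.transform import ToImgTensorF32
from disent.dataset.util.stats import compute_data_mean_std

data = XYObjectData(transform=ToImgTensorF32(size=64))
mean, std = compute_data_mean_std(data, batch_size=256, num_workers=0, progress=True)
# recreate the dataset with standardised observations
data = XYObjectData(transform=ToImgTensorF32(size=64, mean=mean, std=std))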
+ main() + + +# ========================================================================= # +# RESULTS: 2021-10-12 # +# ========================================================================= # + + +# Cars3dData - cars3d: +# mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] +# std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] +# DSpritesData - dsprites: +# mean: [0.042494423521889584] +# std: [0.19516645880626055] +# SmallNorbData - smallnorb: +# mean: [0.7520918401088603] +# std: [0.09563879016827262] +# Shapes3dData - 3dshapes: +# mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] +# std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] +# Mpi3dData - mpi3d_toy: +# mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702] +# std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426] +# Mpi3dData - mpi3d_realistic: +# mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256] +# std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177] +# Mpi3dData - mpi3d_real: +# mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627] +# std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629] +# XYBlocksData - xyblocks: +# mean: [0.10040509259259259, 0.10040509259259259, 0.10040509259259259] +# std: [0.21689087652106678, 0.21689087652106676, 0.21689087652106678] +# XYObjectData - xy_object: +# mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] +# std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] +# XYObjectShadedData - xy_object: +# mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] +# std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] +# XYSquaresData - xy_squares: +# mean: [0.015625, 0.015625, 0.015625] +# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] +# XYSquaresMinimalData - xy_squares: +# mean: [0.015625, 0.015625, 0.015625] +# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] +# XColumnsData - x_columns: +# mean: [0.125, 0.125, 0.125] +# std: [0.33075929223788925, 0.3307592922378891, 0.3307592922378892] + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/frameworks/_framework.py b/disent/frameworks/_framework.py index d006a91b..83dd85b3 100644 --- a/disent/frameworks/_framework.py +++ b/disent/frameworks/_framework.py @@ -96,11 +96,11 @@ def __init__( super().__init__(cfg=cfg) # get the optimizer if isinstance(self.cfg.optimizer, str): - if self.cfg.optimizer not in registry.OPTIMIZER: + if self.cfg.optimizer not in registry.OPTIMIZERS: raise KeyError(f'invalid optimizer: {repr(self.cfg.optimizer)}, valid optimizers are: {sorted(registry.OPTIMIZER)}, otherwise pass a torch.optim.Optimizer class instead.') - self.cfg.optimizer = registry.OPTIMIZER[self.cfg.optimizer] + self.cfg.optimizer = registry.OPTIMIZERS[self.cfg.optimizer] # check the optimizer values - assert isinstance(self.cfg.optimizer, type) and issubclass(self.cfg.optimizer, torch.optim.Optimizer) and (self.cfg.optimizer != torch.optim.Optimizer) + assert callable(self.cfg.optimizer) assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is None), f'invalid optimizer_kwargs type, got: {type(self.cfg.optimizer_kwargs)}' # set default values for optimizer if self.cfg.optimizer_kwargs is None: @@ -119,14 +119,17 @@ def 
__init__( @final def configure_optimizers(self): - optimizer = self.cfg.optimizer - # instantiate the optimizer! - if issubclass(optimizer, torch.optim.Optimizer): - optimizer = optimizer(self.parameters(), **self.cfg.optimizer_kwargs) - elif not isinstance(optimizer, torch.optim.Optimizer): - raise TypeError(f'unsupported optimizer type: {type(optimizer)}') + optimizer_cls = self.cfg.optimizer + # check that we can call the optimizer + if not callable(optimizer_cls): + raise TypeError(f'unsupported optimizer type: {type(optimizer_cls)}') + # instantiate class + optimizer_instance = optimizer_cls(self.parameters(), **self.cfg.optimizer_kwargs) + # check instance + if not isinstance(optimizer_instance, torch.optim.Optimizer): + raise TypeError(f'returned object is not an instance of torch.optim.Optimizer, got: {type(optimizer_instance)}') # return the optimizer - return optimizer + return optimizer_instance @final def training_step(self, batch, batch_idx): diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index dedb9e2f..8e7d2541 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -112,7 +112,8 @@ def recon_handler(self) -> ReconLossHandler: def _get_xs_and_targs(self, batch: Dict[str, Tuple[torch.Tensor, ...]], batch_idx) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]: xs_targ = batch['x_targ'] if 'x' not in batch: - warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ') + # TODO: re-enable this warning but only ever print once! + # warnings.warn('dataset does not have input: x -> x_targ using target as input: x_targ -> x_targ') xs = xs_targ else: xs = batch['x'] diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 63047f79..9821d15e 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -38,7 +38,7 @@ from disent.nn.loss.reduction import batch_loss_reduction from disent.nn.loss.reduction import loss_reduction from disent.nn.modules import DisentModule -from disent.nn.transform import FftKernel +from disent.dataset.transform import FftKernel # ========================================================================= # @@ -144,14 +144,8 @@ class ReconLossHandlerMse(ReconLossHandler): """ def activate(self, x_partial: torch.Tensor) -> torch.Tensor: - # we allow the model output x to generally be in the range [-1, 1] and scale - # it to the range [0, 1] here to match the targets. - # - this lets it learn more easily as the output is naturally centered on 1 - # - doing this here directly on the output is easier for visualisation etc. - # - TODO: the better alternative is that we rather calculate the MEAN and STD over the dataset - # and normalise that. - # - sigmoid is numerically not suitable with MSE - return (x_partial + 1) / 2 + # mse requires no final activation + return x_partial def compute_unreduced_loss(self, x_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor: return F.mse_loss(x_recon, x_targ, reduction='none') @@ -168,10 +162,6 @@ def compute_unreduced_loss(self, x_recon, x_targ): return torch.abs(x_recon - x_targ) - - - - class ReconLossHandlerBce(ReconLossHandler): """ BCE loss should only be used with binary targets {0, 1}. 
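# With the relaxed checks above, `cfg.optimizer` may be a registry key or any
# callable that builds a `torch.optim.Optimizer` from the model parameters;
# `configure_optimizers` now simply calls it and verifies the returned
# instance. A sketch, with BetaVae cfg values used purely for illustration:

from functools import partial
import torch
from disent.frameworks.vae import BetaVae

cfg = BetaVae.cfg(
    optimizer=partial(torch.optim.Adam, betas=(0.5, 0.999)),  # any optimizer factory works
    optimizer_kwargs=dict(lr=1e-3),
)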
@@ -291,9 +281,9 @@ def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_t # NOTE: this function compliments make_kernel in transform/_augment.py def make_reconstruction_loss(name: str, reduction: str) -> ReconLossHandler: - if name in registry.RECON_LOSS: + if name in registry.RECON_LOSSES: # search normal losses! - return registry.RECON_LOSS[name](reduction) + return registry.RECON_LOSSES[name](reduction) else: # regex search losses, and call with args! for r, _, fn in _ARG_RECON_LOSSES: @@ -301,7 +291,7 @@ def make_reconstruction_loss(name: str, reduction: str) -> ReconLossHandler: if result is not None: return fn(reduction, *result.groups()) # we couldn't find anything - raise KeyError(f'Invalid vae reconstruction loss: {repr(name)} Valid losses include: {sorted(registry.RECON_LOSS)}, examples of additional argument based losses include: {[example for _, example, _ in _ARG_RECON_LOSSES]}') + raise KeyError(f'Invalid vae reconstruction loss: {repr(name)} Valid losses include: {sorted(registry.RECON_LOSSES)}, examples of additional argument based losses include: {[example for _, example, _ in _ARG_RECON_LOSSES]}') # ========================================================================= # diff --git a/disent/frameworks/vae/__init__.py b/disent/frameworks/vae/__init__.py index f7ac9ef2..08a5ffe1 100644 --- a/disent/frameworks/vae/__init__.py +++ b/disent/frameworks/vae/__init__.py @@ -35,3 +35,4 @@ # weakly supervised frameworks from disent.frameworks.vae._weaklysupervised__adavae import AdaVae +from disent.frameworks.vae._weaklysupervised__adavae import AdaGVaeMinimal diff --git a/disent/frameworks/vae/_unsupervised__dfcvae.py b/disent/frameworks/vae/_unsupervised__dfcvae.py index 1bc93b0f..a44cfefe 100644 --- a/disent/frameworks/vae/_unsupervised__dfcvae.py +++ b/disent/frameworks/vae/_unsupervised__dfcvae.py @@ -41,7 +41,7 @@ from disent.frameworks.helper.util import compute_ave_loss from disent.frameworks.vae._unsupervised__betavae import BetaVae from disent.nn.loss.reduction import get_mean_loss_scale -from disent.nn.transform.functional import check_tensor +from disent.dataset.transform.functional import check_tensor # ========================================================================= # diff --git a/disent/frameworks/vae/_weaklysupervised__adavae.py b/disent/frameworks/vae/_weaklysupervised__adavae.py index f2c64357..4434711d 100644 --- a/disent/frameworks/vae/_weaklysupervised__adavae.py +++ b/disent/frameworks/vae/_weaklysupervised__adavae.py @@ -30,6 +30,7 @@ import torch from dataclasses import dataclass from torch.distributions import Distribution +from torch.distributions import kl_divergence from torch.distributions import Normal from disent.frameworks.vae._unsupervised__betavae import BetaVae @@ -50,6 +51,11 @@ class AdaVae(BetaVae): MODIFICATION: - Symmetric KL Calculation used by default, described in: https://arxiv.org/pdf/2010.14407.pdf - adjustable threshold value + + * This class is a little over complicated because it has the added functionality + listed above, and because we want to re-use features elsewhere. The code can + be compressed down into about ~20 neat lines for `hook_intercept_ds` if we + select and chose fixed `cfg` values. """ REQUIRED_OBS = 2 @@ -73,20 +79,20 @@ def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequ """ d0_posterior, d1_posterior = ds_posterior # shared elements that need to be averaged, computed per pair in the batch. 
- share_mask = self.compute_posterior_shared_mask(d0_posterior, d1_posterior, thresh_mode=self.cfg.ada_thresh_mode, ratio=self.cfg.ada_thresh_ratio) + share_mask = self.compute_shared_mask_from_posteriors(d0_posterior, d1_posterior, thresh_mode=self.cfg.ada_thresh_mode, ratio=self.cfg.ada_thresh_ratio) # compute average posteriors - new_ds_posterior = self.make_averaged_distributions(d0_posterior, d1_posterior, share_mask, average_mode=self.cfg.ada_average_mode) + new_ds_posterior = self.make_shared_posteriors(d0_posterior, d1_posterior, share_mask, average_mode=self.cfg.ada_average_mode) # return new args & generate logs return new_ds_posterior, ds_prior, { 'shared': share_mask.sum(dim=1).float().mean() } # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - # HELPER # + # HELPER - POSTERIORS # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @classmethod - def compute_posterior_deltas(cls, d0_posterior: Distribution, d1_posterior: Distribution, thresh_mode: str): + def compute_deltas_from_posteriors(cls, d0_posterior: Distribution, d1_posterior: Distribution, thresh_mode: str): """ (✓) Visual inspection against reference implementation https://github.com/google-research/disentanglement_lib (compute_kl) @@ -98,16 +104,16 @@ def compute_posterior_deltas(cls, d0_posterior: Distribution, d1_posterior: Dist # [𝛿_i ...] if thresh_mode == 'kl': # ORIGINAL - deltas = torch.distributions.kl_divergence(d1_posterior, d0_posterior) + deltas = kl_divergence(d1_posterior, d0_posterior) elif thresh_mode == 'symmetric_kl': # FROM: https://openreview.net/pdf?id=8VXvj1QNRl1 - kl_deltas_d1_d0 = torch.distributions.kl_divergence(d1_posterior, d0_posterior) - kl_deltas_d0_d1 = torch.distributions.kl_divergence(d0_posterior, d1_posterior) + kl_deltas_d1_d0 = kl_divergence(d1_posterior, d0_posterior) + kl_deltas_d0_d1 = kl_divergence(d0_posterior, d1_posterior) deltas = (0.5 * kl_deltas_d1_d0) + (0.5 * kl_deltas_d0_d1) elif thresh_mode == 'dist': - deltas = cls.compute_z_deltas(d1_posterior.mean, d0_posterior.mean) + deltas = cls.compute_deltas_from_zs(d1_posterior.mean, d0_posterior.mean) elif thresh_mode == 'sampled_dist': - deltas = cls.compute_z_deltas(d1_posterior.rsample(), d0_posterior.rsample()) + deltas = cls.compute_deltas_from_zs(d1_posterior.rsample(), d0_posterior.rsample()) else: raise KeyError(f'invalid thresh_mode: {repr(thresh_mode)}') @@ -115,23 +121,52 @@ def compute_posterior_deltas(cls, d0_posterior: Distribution, d1_posterior: Dist return deltas @classmethod - def compute_posterior_shared_mask(cls, d0_posterior: Distribution, d1_posterior: Distribution, thresh_mode: str, ratio=0.5): - return cls.estimate_shared_mask(z_deltas=cls.compute_posterior_deltas(d0_posterior, d1_posterior, thresh_mode=thresh_mode), ratio=ratio) + def compute_shared_mask_from_posteriors(cls, d0_posterior: Distribution, d1_posterior: Distribution, thresh_mode: str, ratio=0.5): + return cls.estimate_shared_mask(z_deltas=cls.compute_deltas_from_posteriors(d0_posterior, d1_posterior, thresh_mode=thresh_mode), ratio=ratio) @classmethod - def compute_z_deltas(cls, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor: + def make_shared_posteriors(cls, d0_posterior: Normal, d1_posterior: Normal, share_mask: torch.Tensor, average_mode: str) -> Tuple[Normal, Normal]: + # compute average posterior + ave_posterior = AdaVae.compute_average_distribution(d0_posterior=d0_posterior, d1_posterior=d1_posterior, average_mode=average_mode) + # select shared elements + ave_d0_posterior = 
Normal(loc=torch.where(share_mask, ave_posterior.loc, d0_posterior.loc), scale=torch.where(share_mask, ave_posterior.scale, d0_posterior.scale)) + ave_d1_posterior = Normal(loc=torch.where(share_mask, ave_posterior.loc, d1_posterior.loc), scale=torch.where(share_mask, ave_posterior.scale, d1_posterior.scale)) + # return values + return ave_d0_posterior, ave_d1_posterior + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # HELPER - MEAN/MU VALUES (same functionality as posterior versions) # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + @classmethod + def compute_deltas_from_zs(cls, z0: torch.Tensor, z1: torch.Tensor) -> torch.Tensor: return torch.abs(z0 - z1) @classmethod - def compute_z_shared_mask(cls, z0: torch.Tensor, z1: torch.Tensor, ratio: float = 0.5): - return cls.estimate_shared_mask(z_deltas=cls.compute_z_deltas(z0, z1), ratio=ratio) + def compute_shared_mask_from_zs(cls, z0: torch.Tensor, z1: torch.Tensor, ratio: float = 0.5): + return cls.estimate_shared_mask(z_deltas=cls.compute_deltas_from_zs(z0, z1), ratio=ratio) + + @classmethod + def make_shared_zs(cls, z0: torch.Tensor, z1: torch.Tensor, share_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # compute average values + ave = 0.5 * z0 + 0.5 * z1 + # select shared elements + ave_z0 = torch.where(share_mask, ave, z0) + ave_z1 = torch.where(share_mask, ave, z1) + return ave_z0, ave_z1 + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # HELPER - COMMON # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # @classmethod - def estimate_shared_mask(cls, z_deltas: torch.Tensor, ratio: float = 0.5): + def estimate_shared_mask(cls, z_deltas: torch.Tensor, ratio: float) -> torch.Tensor: """ Core of the adaptive VAE algorithm, estimating which factors have changed (or in this case which are shared and should remained unchanged by being be averaged) between pairs of observations. + - custom ratio is an addition, when ratio==0.5 then + this is equivalent to the original implementation. (✓) Visual inspection against reference implementation: https://github.com/google-research/disentanglement_lib (aggregate_argmax) @@ -144,48 +179,26 @@ def estimate_shared_mask(cls, z_deltas: torch.Tensor, ratio: float = 0.5): and enforce that each factor of variation is encoded in a single dimension." """ + assert 0 <= ratio <= 1, f'ratio must be in the range: 0 <= ratio <= 1, got: {repr(ratio)}' # threshold τ - z_threshs = cls.estimate_threshold(z_deltas, ratio=ratio) + maximums = z_deltas.max(axis=1, keepdim=True).values # (B, 1) + minimums = z_deltas.min(axis=1, keepdim=True).values # (B, 1) + z_threshs = torch.lerp(minimums, maximums, weight=ratio) # (B, 1) # true if 'unchanged' and should be average - shared_mask = z_deltas < z_threshs + shared_mask = z_deltas < z_threshs # broadcast (B, Z) and (B, 1) -> (B, Z) # return return shared_mask - @classmethod - def estimate_threshold(cls, kl_deltas: torch.Tensor, keepdim: bool = True, ratio: float = 0.5): - """ - Compute the threshold for each image pair in a batch of kl divergences of all elements of the latent distributions. - It should be noted that for a perfectly trained model, this threshold is always correct. 
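# The custom ratio generalises the original threshold: per pair it is
# lerp(min_delta, max_delta, ratio), so ratio=0.5 recovers the old midpoint
# behaviour. A tiny numeric sketch of the masking step:

import torch

z_deltas = torch.tensor([[0.1, 0.2, 0.9, 1.1]])
z_thresh = torch.lerp(
    z_deltas.min(dim=1, keepdim=True).values,  # 0.1
    z_deltas.max(dim=1, keepdim=True).values,  # 1.1
    0.5,                                       # -> threshold of 0.6
)
share_mask = z_deltas < z_thresh               # [[True, True, False, False]]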
-
-        (✓) Visual inspection against reference implementation:
-            https://github.com/google-research/disentanglement_lib (aggregate_argmax)
-        """
-        maximums = kl_deltas.max(axis=1, keepdim=keepdim).values
-        minimums = kl_deltas.min(axis=1, keepdim=keepdim).values
-        return torch.lerp(minimums, maximums, weight=ratio)
+    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+    # HELPER - DISTRIBUTIONS                                                #
+    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #

     @classmethod
-    def make_averaged_distributions(cls, d0_posterior: Normal, d1_posterior: Normal, share_mask: torch.Tensor, average_mode: str):
-        # compute average posterior
-        ave_posterior = compute_average_distribution(d0_posterior=d0_posterior, d1_posterior=d1_posterior, average_mode=average_mode)
-        # select averages
-        ave_z0_posterior = ave_posterior.__class__(
-            loc=torch.where(share_mask, ave_posterior.loc, d0_posterior.loc),
-            scale=torch.where(share_mask, ave_posterior.scale, d0_posterior.scale),
-        )
-        ave_z1_posterior = ave_posterior.__class__(
-            loc=torch.where(share_mask, ave_posterior.loc, d1_posterior.loc),
-            scale=torch.where(share_mask, ave_posterior.scale, d1_posterior.scale),
+    def compute_average_distribution(cls, d0_posterior: Normal, d1_posterior: Normal, average_mode: str) -> Normal:
+        return _COMPUTE_AVE_FNS[average_mode](
+            d0_posterior=d0_posterior,
+            d1_posterior=d1_posterior,
         )
-        # return values
-        return ave_z0_posterior, ave_z1_posterior
-
-    @classmethod
-    def make_averaged_zs(cls, z0: torch.Tensor, z1: torch.Tensor, share_mask: torch.Tensor):
-        ave = 0.5 * z0 + 0.5 * z1
-        ave_z0 = torch.where(share_mask, ave, z0)
-        ave_z1 = torch.where(share_mask, ave, z1)
-        return ave_z0, ave_z1


 # ========================================================================= #
@@ -193,7 +206,24 @@ def make_averaged_zs(cls, z0: torch.Tensor, z1: torch.Tensor, share_mask: torch.
 # ========================================================================= #


-def compute_average_gvae(z0_mean: torch.Tensor, z0_var: torch.Tensor, z1_mean: torch.Tensor, z1_var: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def compute_average_gvae_std(d0_posterior: Normal, d1_posterior: Normal) -> Normal:
+    """
+    Compute the arithmetic mean of the encoder distributions.
+    - This is a custom function based on the Ada-GVAE averaging,
+      except over the standard deviation instead of the variance!
+
+    *NB* this is un-official!
+    """
+    assert isinstance(d0_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d0_posterior)}'
+    assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
+    # averages
+    ave_std = 0.5 * (d0_posterior.stddev + d1_posterior.stddev)
+    ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean)
+    # done!
+    return Normal(loc=ave_mean, scale=ave_std)
+
+
+def compute_average_gvae(d0_posterior: Normal, d1_posterior: Normal) -> Normal:
     """
     Compute the arithmetic mean of the encoder distributions.
     - Ada-GVAE Averaging function
@@ -201,15 +231,16 @@ def compute_average_gvae(z0_mean: torch.Tensor, z0_var: torch.Tensor, z1_mean: t
     (✓) Visual inspection against reference implementation:
         https://github.com/google-research/disentanglement_lib (GroupVAEBase.model_fn)
     """
-    # TODO: would the mean of the std be better?
+    assert isinstance(d0_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d0_posterior)}'
+    assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
     # averages
-    ave_var = 0.5 * (z0_var + z1_var)
-    ave_mean = 0.5 * (z0_mean + z1_mean)
-    # mean, logvar
-    return ave_mean, ave_var  # natural log
+    ave_var = 0.5 * (d0_posterior.variance + d1_posterior.variance)
+    ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean)
+    # done!
+    return Normal(loc=ave_mean, scale=torch.sqrt(ave_var))


-def compute_average_ml_vae(z0_mean: torch.Tensor, z0_var: torch.Tensor, z1_mean: torch.Tensor, z1_var: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+def compute_average_ml_vae(d0_posterior: Normal, d1_posterior: Normal) -> Normal:
     """
     Compute the product of the encoder distributions.
     - Ada-ML-VAE Averaging function
@@ -219,42 +250,93 @@ def compute_average_ml_vae(z0_mean: torch.Tensor, z0_var: torch.Tensor, z1_mean:
     # TODO: recheck
     """
+    assert isinstance(d0_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d0_posterior)}'
+    assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
     # Diagonal matrix inverse: E^-1 = 1 / E
     # https://proofwiki.org/wiki/Inverse_of_Diagonal_Matrix
-    z0_invvar, z1_invvar = z0_var.reciprocal(), z1_var.reciprocal()
+    z0_invvar, z1_invvar = d0_posterior.variance.reciprocal(), d1_posterior.variance.reciprocal()
     # average var: E^-1 = E1^-1 + E2^-1
     # disentanglement_lib: ave_var = 2 * z0_var * z1_var / (z0_var + z1_var)
     ave_var = 2 * (z0_invvar + z1_invvar).reciprocal()
     # average mean: u^T = (u1^T E1^-1 + u2^T E2^-1) E
     # disentanglement_lib: ave_mean = (z0_mean/z0_var + z1_mean/z1_var) * ave_var * 0.5
-    ave_mean = (z0_mean*z0_invvar + z1_mean*z1_invvar) * ave_var * 0.5
-    # mean, logvar
-    return ave_mean, ave_var  # natural log
+    ave_mean = (d0_posterior.mean*z0_invvar + d1_posterior.mean*z1_invvar) * ave_var * 0.5
+    # done!
+    return Normal(loc=ave_mean, scale=torch.sqrt(ave_var))


-COMPUTE_AVE_FNS = {
+_COMPUTE_AVE_FNS = {
     'gvae': compute_average_gvae,
     'ml-vae': compute_average_ml_vae,
+    'gvae_std': compute_average_gvae_std,  # this is un-official!
} -def compute_average(z0_mean: torch.Tensor, z0_var: torch.Tensor, z1_mean: torch.Tensor, z1_var: torch.Tensor, average_mode: str) -> Tuple[torch.Tensor, torch.Tensor]: - return COMPUTE_AVE_FNS[average_mode](z0_mean=z0_mean, z0_var=z0_var, z1_mean=z1_mean, z1_var=z1_var) +# ========================================================================= # +# Ada-GVAE # +# ========================================================================= # -def compute_average_distribution(d0_posterior: Normal, d1_posterior: Normal, average_mode: str) -> Normal: - assert isinstance(d0_posterior, Normal) and isinstance(d1_posterior, Normal) - ave_mean, ave_var = compute_average( - z0_mean=d0_posterior.mean, z0_var=d0_posterior.variance, - z1_mean=d1_posterior.mean, z1_var=d1_posterior.variance, - average_mode=average_mode, - ) - return Normal(loc=ave_mean, scale=torch.sqrt(ave_var)) +class AdaGVaeMinimal(BetaVae): + """ + This is a direct implementation of the Ada-GVAE, + which should be equivalent to the AdaVae with config values: + + >>> AdaVae.cfg( + >>> ada_average_mode='gvae', + >>> ada_thresh_mode='symmetric_kl', + >>> ada_thresh_ratio=0.5, + >>> ) + """ + REQUIRED_OBS = 2 -# def compute_average_params(z0_params: 'Params', z1_params: 'Params', average_mode: str) -> 'Params': -# ave_mean, ave_logvar = compute_average(z0_params.mean, z0_params.logvar, z1_params.mean, z1_params.logvar, average_mode=average_mode) -# return z0_params.__class__(ave_mean, ave_logvar) + def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]) -> Tuple[Sequence[Distribution], Sequence[Distribution], Dict[str, Any]]: + """ + Adaptive VAE Method, putting the various components together + 1. compute differences between representations + 2. estimate a threshold for differences + 3. compute a shared mask from this threshold + 4. 
average together elements that are marked as shared
+
+        (x) Visual inspection against reference implementation:
+            https://github.com/google-research/disentanglement_lib (aggregate_argmax)
+        """
+        d0_posterior, d1_posterior = ds_posterior
+        assert isinstance(d0_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d0_posterior)}'
+        assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
+
+        # [1] symmetric KL Divergence FROM: https://openreview.net/pdf?id=8VXvj1QNRl1
+        z_deltas = 0.5 * kl_divergence(d1_posterior, d0_posterior) + 0.5 * kl_divergence(d0_posterior, d1_posterior)
+
+        # [2] estimate threshold from deltas
+        z_deltas_min = z_deltas.min(axis=1, keepdim=True).values  # (B, 1)
+        z_deltas_max = z_deltas.max(axis=1, keepdim=True).values  # (B, 1)
+        z_thresh = (0.5 * z_deltas_min + 0.5 * z_deltas_max)      # (B, 1)
+
+        # [3] shared elements that need to be averaged, computed per pair in the batch
+        share_mask = z_deltas < z_thresh  # broadcast (B, Z) and (B, 1) to get (B, Z)
+
+        # [4.a] compute average representations
+        # - the averaging function is the only difference from the Ada-ML-VAE
+        ave_mean = (0.5 * d0_posterior.mean + 0.5 * d1_posterior.mean)
+        ave_std = (0.5 * d0_posterior.variance + 0.5 * d1_posterior.variance) ** 0.5
+
+        # [4.b] take the averaged values where shared, otherwise keep the originals
+        z0_mean = torch.where(share_mask, ave_mean, d0_posterior.loc)
+        z1_mean = torch.where(share_mask, ave_mean, d1_posterior.loc)
+        z0_std = torch.where(share_mask, ave_std, d0_posterior.scale)
+        z1_std = torch.where(share_mask, ave_std, d1_posterior.scale)
+
+        # construct distributions
+        ave_d0_posterior = Normal(loc=z0_mean, scale=z0_std)
+        ave_d1_posterior = Normal(loc=z1_mean, scale=z1_std)
+        new_ds_posterior = (ave_d0_posterior, ave_d1_posterior)
+
+        # [done] return new args & generate logs
+        return new_ds_posterior, ds_prior, {
+            'shared': share_mask.sum(dim=1).float().mean()
+        }


 # ========================================================================= #
diff --git a/disent/metrics/_dci.py b/disent/metrics/_dci.py
index 5c3d67fb..2c48bcd5 100644
--- a/disent/metrics/_dci.py
+++ b/disent/metrics/_dci.py
@@ -45,7 +45,7 @@ def metric_dci(
-        ground_truth_dataset: DisentDataset,
+        dataset: DisentDataset,
         representation_function: callable,
         num_train: int = 10000,
         num_test: int = 5000,
@@ -55,7 +55,7 @@ def metric_dci(
 ):
     """Computes the DCI scores according to Sec 2.
     Args:
-      ground_truth_dataset: GroundTruthData to be sampled from.
+      dataset: DisentDataset to be sampled from.
       representation_function: Function that takes observations as input and
         outputs a dim_representation sized representation for each observation.
       num_train: Number of points used for training.
@@ -70,10 +70,10 @@ def metric_dci(
     log.debug("Generating training set.")
     # mus_train are of shape [num_codes, num_train], while ys_train are of shape
     # [num_factors, num_train].
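# With the rename above, the metrics take the wrapped DisentDataset via the
# `dataset` argument. A small sketch (the flattening representation function is
# only a stand-in for a trained encoder, and the reduced num_train/num_test
# values are purely illustrative):

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import SingleSampler
from disent.dataset.transform import ToImgTensorF32
from disent.metrics import metric_dci

dataset = DisentDataset(XYObjectData(), sampler=SingleSampler(), transform=ToImgTensorF32())
scores = metric_dci(
    dataset=dataset,
    representation_function=lambda x: x.flatten(start_dim=1),
    num_train=1000,
    num_test=500,
)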
- mus_train, ys_train = utils.generate_batch_factor_code(ground_truth_dataset, representation_function, num_train, batch_size, show_progress=False) + mus_train, ys_train = utils.generate_batch_factor_code(dataset, representation_function, num_train, batch_size, show_progress=False) assert mus_train.shape[1] == num_train assert ys_train.shape[1] == num_train - mus_test, ys_test = utils.generate_batch_factor_code(ground_truth_dataset, representation_function, num_test, batch_size, show_progress=False) + mus_test, ys_test = utils.generate_batch_factor_code(dataset, representation_function, num_test, batch_size, show_progress=False) log.debug("Computing DCI metric.") scores = _compute_dci(mus_train, ys_train, mus_test, ys_test, boost_mode=boost_mode, show_progress=show_progress) diff --git a/disent/metrics/_factor_vae.py b/disent/metrics/_factor_vae.py index 309346b4..66ffc67a 100644 --- a/disent/metrics/_factor_vae.py +++ b/disent/metrics/_factor_vae.py @@ -44,7 +44,7 @@ def metric_factor_vae( - ground_truth_dataset: DisentDataset, + dataset: DisentDataset, representation_function: callable, batch_size: int = 64, num_train: int = 10000, @@ -84,7 +84,7 @@ def metric_factor_vae( Most importantly, it circumvents the failure mode of the earlier metric, since the classifier needs to see the lowest variance in a latent dimension for a given factor to classify it correctly Args: - ground_truth_dataset: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. batch_size: Number of points to be used to compute the training_sample. @@ -99,7 +99,7 @@ def metric_factor_vae( """ log.debug("Computing global variances to standardise.") - global_variances = _compute_variances(ground_truth_dataset, representation_function, num_variance_estimate) + global_variances = _compute_variances(dataset, representation_function, num_variance_estimate) active_dims = _prune_dims(global_variances) if not active_dims.any(): @@ -110,7 +110,7 @@ def metric_factor_vae( } log.debug("Generating training set.") - training_votes = _generate_training_batch(ground_truth_dataset, representation_function, batch_size, num_train, global_variances, active_dims, show_progress=show_progress) + training_votes = _generate_training_batch(dataset, representation_function, batch_size, num_train, global_variances, active_dims, show_progress=show_progress) classifier = np.argmax(training_votes, axis=0) other_index = np.arange(training_votes.shape[1]) @@ -118,7 +118,7 @@ def metric_factor_vae( train_accuracy = np.sum(training_votes[classifier, other_index]) * 1. / np.sum(training_votes) log.debug("Generating evaluation set.") - eval_votes = _generate_training_batch(ground_truth_dataset, representation_function, batch_size, num_eval, global_variances, active_dims, show_progress=show_progress) + eval_votes = _generate_training_batch(dataset, representation_function, batch_size, num_eval, global_variances, active_dims, show_progress=show_progress) # Evaluate evaluation set accuracy eval_accuracy = np.sum(eval_votes[classifier, other_index]) * 1. / np.sum(eval_votes) @@ -137,21 +137,21 @@ def _prune_dims(variances, threshold=0.): def _compute_variances( - ground_truth_dataset: DisentDataset, + dataset: DisentDataset, representation_function: callable, batch_size: int, eval_batch_size: int = 64 ): """Computes the variance for each dimension of the representation. 
Args: - ground_truth_dataset: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observation as input and outputs a representation. batch_size: Number of points to be used to compute the variances. eval_batch_size: Batch size used to eval representation. Returns: Vector with the variance of each dimension. """ - observations = ground_truth_dataset.dataset_sample_batch(batch_size, mode='input') + observations = dataset.dataset_sample_batch(batch_size, mode='input') representations = to_numpy(utils.obtain_representation(observations, representation_function, eval_batch_size)) representations = np.transpose(representations) assert representations.shape[0] == batch_size @@ -159,7 +159,7 @@ def _compute_variances( def _generate_training_sample( - ground_truth_dataset: DisentDataset, + dataset: DisentDataset, representation_function: callable, batch_size: int, global_variances: np.ndarray, @@ -167,7 +167,7 @@ def _generate_training_sample( ) -> (int, int): """Sample a single training sample based on a mini-batch of ground-truth data. Args: - ground_truth_dataset: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observation as input and outputs a representation. batch_size: Number of points to be used to compute the training_sample. @@ -178,13 +178,13 @@ def _generate_training_sample( argmin: Index of representation coordinate with the least variance. """ # Select random coordinate to keep fixed. - factor_index = np.random.randint(ground_truth_dataset.ground_truth_data.num_factors) + factor_index = np.random.randint(dataset.gt_data.num_factors) # Sample two mini batches of latent variables. - factors = ground_truth_dataset.ground_truth_data.sample_factors(batch_size) + factors = dataset.gt_data.sample_factors(batch_size) # Fix the selected factor across mini-batch. factors[:, factor_index] = factors[0, factor_index] # Obtain the observations. - observations = ground_truth_dataset.dataset_batch_from_factors(factors, mode='input') + observations = dataset.dataset_batch_from_factors(factors, mode='input') representations = to_numpy(representation_function(observations)) local_variances = np.var(representations, axis=0, ddof=1) argmin = np.argmin(local_variances[active_dims] / global_variances[active_dims]) @@ -192,7 +192,7 @@ def _generate_training_sample( def _generate_training_batch( - ground_truth_dataset: DisentDataset, + dataset: DisentDataset, representation_function: callable, batch_size: int, num_points: int, @@ -202,7 +202,7 @@ def _generate_training_batch( ): """Sample a set of training samples based on a batch of ground-truth data. Args: - ground_truth_dataset: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. batch_size: Number of points to be used to compute the training_sample. num_points: Number of points to be sampled for training set. @@ -211,9 +211,9 @@ def _generate_training_batch( Returns: (num_factors, dim_representation)-sized numpy array with votes. 
""" - votes = np.zeros((ground_truth_dataset.ground_truth_data.num_factors, global_variances.shape[0]), dtype=np.int64) + votes = np.zeros((dataset.gt_data.num_factors, global_variances.shape[0]), dtype=np.int64) for _ in tqdm(range(num_points), disable=(not show_progress)): - factor_index, argmin = _generate_training_sample(ground_truth_dataset, representation_function, batch_size, global_variances, active_dims) + factor_index, argmin = _generate_training_sample(dataset, representation_function, batch_size, global_variances, active_dims) votes[factor_index, argmin] += 1 return votes diff --git a/disent/metrics/_mig.py b/disent/metrics/_mig.py index df37ba97..3e0b1a87 100644 --- a/disent/metrics/_mig.py +++ b/disent/metrics/_mig.py @@ -43,14 +43,14 @@ def metric_mig( - ground_truth_data: DisentDataset, + dataset: DisentDataset, representation_function, num_train=10000, batch_size=16, ): """Computes the mutual information gap. Args: - ground_truth_data: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. num_train: Number of points used for training. @@ -59,7 +59,7 @@ def metric_mig( Dict with average mutual information gap. """ log.debug("Generating training set.") - mus_train, ys_train = utils.generate_batch_factor_code(ground_truth_data, representation_function, num_train, batch_size) + mus_train, ys_train = utils.generate_batch_factor_code(dataset, representation_function, num_train, batch_size) assert mus_train.shape[1] == num_train return _compute_mig(mus_train, ys_train) diff --git a/disent/metrics/_sap.py b/disent/metrics/_sap.py index 42be33a9..36fad48d 100644 --- a/disent/metrics/_sap.py +++ b/disent/metrics/_sap.py @@ -44,7 +44,7 @@ def metric_sap( - ground_truth_data: DisentDataset, + dataset: DisentDataset, representation_function, num_train=10000, num_test=5000, @@ -53,7 +53,7 @@ def metric_sap( ): """Computes the SAP score. Args: - ground_truth_data: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. num_train: Number of points used for training. @@ -64,8 +64,8 @@ def metric_sap( Dictionary with SAP score. """ log.debug("Generating training set.") - mus, ys = utils.generate_batch_factor_code(ground_truth_data, representation_function, num_train, batch_size) - mus_test, ys_test = utils.generate_batch_factor_code(ground_truth_data, representation_function, num_test, batch_size) + mus, ys = utils.generate_batch_factor_code(dataset, representation_function, num_train, batch_size) + mus_test, ys_test = utils.generate_batch_factor_code(dataset, representation_function, num_test, batch_size) log.debug("Computing score matrix.") return _compute_sap(mus, ys, mus_test, ys_test, continuous_factors) diff --git a/disent/metrics/_unsupervised.py b/disent/metrics/_unsupervised.py index 8cea782d..c872f110 100644 --- a/disent/metrics/_unsupervised.py +++ b/disent/metrics/_unsupervised.py @@ -41,14 +41,14 @@ def metric_unsupervised( - ground_truth_data: DisentDataset, + dataset: DisentDataset, representation_function, num_train=10000, batch_size=16 ): """Computes unsupervised scores based on covariance and mutual information. Args: - ground_truth_data: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. 
representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. random_state: Numpy random state used for randomness. @@ -59,7 +59,7 @@ def metric_unsupervised( Dictionary with scores. """ log.debug("Generating training set.") - mus_train, _ = utils.generate_batch_factor_code(ground_truth_data, representation_function, num_train, batch_size) + mus_train, _ = utils.generate_batch_factor_code(dataset, representation_function, num_train, batch_size) num_codes = mus_train.shape[0] cov_mus = np.cov(mus_train) assert num_codes == cov_mus.shape[0] diff --git a/disent/metrics/utils.py b/disent/metrics/utils.py index cdb18006..e8c6b2fe 100644 --- a/disent/metrics/utils.py +++ b/disent/metrics/utils.py @@ -37,15 +37,15 @@ def generate_batch_factor_code( - ground_truth_dataset: DisentDataset, + dataset: DisentDataset, representation_function, - num_points, - batch_size, - show_progress=False, + num_points: int, + batch_size: int, + show_progress: bool = False, ): """Sample a single training sample based on a mini-batch of ground-truth data. Args: - ground_truth_dataset: GroundTruthData to be sampled from. + dataset: DisentDataset to be sampled from. representation_function: Function that takes observation as input and outputs a representation. num_points: Number of points to sample. batch_size: Batchsize to sample points. @@ -62,7 +62,7 @@ def generate_batch_factor_code( with tqdm(total=num_points, disable=not show_progress) as bar: while i < num_points: num_points_iter = min(num_points - i, batch_size) - current_observations, current_factors = ground_truth_dataset.dataset_sample_batch_with_factors(num_points_iter, mode='input') + current_observations, current_factors = dataset.dataset_sample_batch_with_factors(num_points_iter, mode='input') if i == 0: factors = current_factors representations = to_numpy(representation_function(current_observations)) diff --git a/disent/model/_base.py b/disent/model/_base.py index 85acfa3c..8a9c6203 100644 --- a/disent/model/_base.py +++ b/disent/model/_base.py @@ -50,10 +50,10 @@ class DisentLatentsModule(DisentModule): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): super().__init__() - self._x_shape = x_shape + self._x_shape = tuple(x_shape) self._x_size = int(np.prod(x_shape)) - self._z_size = z_size - self._z_multiplier = z_multiplier + self._z_size = int(z_size) + self._z_multiplier = int(z_multiplier) def forward(self, *args, **kwargs): raise NotImplementedError @@ -86,7 +86,7 @@ class DisentEncoder(DisentLatentsModule): def forward(self, x, chunk=True) -> torch.Tensor: """same as self.encode but with size checks""" # checks - assert x.ndim == 4, f'ndim mismatch: 4 (required) != {x.ndim} (given)' + assert x.ndim == 4, f'ndim mismatch: 4 (required) != {x.ndim} (given) [shape={x.shape}]' assert x.shape[1:] == self.x_shape, f'x_shape mismatch: {self.x_shape} (required) != {x.shape[1:]} (batch)' # encode | p(z|x) # for a gaussian encoder, we treat z as concat(z_mean, z_logvar) where z_mean.shape == z_logvar.shape diff --git a/disent/model/ae/__init__.py b/disent/model/ae/__init__.py index 7d395b28..29f66f84 100644 --- a/disent/model/ae/__init__.py +++ b/disent/model/ae/__init__.py @@ -29,5 +29,5 @@ from disent.model.ae._norm_conv64 import EncoderConv64Norm from disent.model.ae._vae_fc import DecoderFC from disent.model.ae._vae_fc import EncoderFC -from disent.model.ae._test import DecoderTest -from disent.model.ae._test import EncoderTest +from disent.model.ae._linear 
import DecoderLinear +from disent.model.ae._linear import EncoderLinear diff --git a/disent/model/ae/_test.py b/disent/model/ae/_linear.py similarity index 97% rename from disent/model/ae/_test.py rename to disent/model/ae/_linear.py index f6c1ee34..c4f9dbe5 100644 --- a/disent/model/ae/_test.py +++ b/disent/model/ae/_linear.py @@ -35,7 +35,7 @@ # ========================================================================= # -class EncoderTest(DisentEncoder): +class EncoderLinear(DisentEncoder): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): super().__init__(x_shape=x_shape, z_size=z_size, z_multiplier=z_multiplier) @@ -49,7 +49,7 @@ def encode(self, x) -> (Tensor, Tensor): return self.model(x) -class DecoderTest(DisentDecoder): +class DecoderLinear(DisentDecoder): def __init__(self, x_shape=(3, 64, 64), z_size=6, z_multiplier=1): super().__init__(x_shape=x_shape, z_size=z_size, z_multiplier=z_multiplier) diff --git a/disent/nn/functional/__init__.py b/disent/nn/functional/__init__.py index be20449a..4ac90f16 100644 --- a/disent/nn/functional/__init__.py +++ b/disent/nn/functional/__init__.py @@ -22,569 +22,45 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import logging -import warnings -from typing import List -from typing import Optional -from typing import Union - -import numpy as np -import torch - -from disent.nn.functional._generic_tensors import generic_as_int32 -from disent.nn.functional._generic_tensors import generic_max -from disent.nn.functional._generic_tensors import TypeGenericTensor -from disent.nn.functional._generic_tensors import TypeGenericTorch - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# pytorch math correlation functions # -# ========================================================================= # - - -def torch_cov_matrix(xs: torch.Tensor): - """ - Calculate the covariance matrix of multiple samples (N) of random vectors of size (X) - https://en.wikipedia.org/wiki/Covariance_matrix - - The input shape is: (N, X) - - The output shape is: (X, X) - - This should be the same as: - np.cov(xs, rowvar=False, ddof=0) - """ - # NOTE: - # torch.mm is strict matrix multiplication - # however if we multiply arrays with broadcasting: - # size(3, 1) * size(1, 2) -> size(3, 2) # broadcast, not matmul - # size(1, 3) * size(2, 1) -> size(2, 3) # broadcast, not matmul - # CHECK: - assert xs.ndim == 2 # (N, X) - Rxx = torch.mean(xs[:, :, None] * xs[:, None, :], dim=0) # (X, X) - ux = torch.mean(xs, dim=0) # (X,) - Kxx = Rxx - (ux[:, None] * ux[None, :]) # (X, X) - return Kxx - - -def torch_corr_matrix(xs: torch.Tensor): - """ - Calculate the pearson's correlation matrix of multiple samples (N) of random vectors of size (X) - https://en.wikipedia.org/wiki/Pearson_correlation_coefficient - https://en.wikipedia.org/wiki/Covariance_matrix - - The input shape is: (N, X) - - The output shape is: (X, X) - - This should be the same as: - np.corrcoef(xs, rowvar=False, ddof=0) - """ - Kxx = torch_cov_matrix(xs) - diag_Kxx = torch.rsqrt(torch.diagonal(Kxx)) - corr = Kxx * (diag_Kxx[:, None] * diag_Kxx[None, :]) - return corr - - -def torch_rank_corr_matrix(xs: torch.Tensor): - """ - Calculate the spearman's rank correlation matrix of multiple samples (N) of random vectors of size (X) - https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient - - The input shape is: (N, X) - - The output shape is: (X, X) - - Pearson's correlation 
measures linear relationships - Spearman's correlation measures monotonic relationships (whether linear or not) - - defined in terms of the pearson's correlation matrix of the rank variables - - TODO: check, be careful of repeated values, this might not give the correct input? - """ - rs = torch.argsort(xs, dim=0, descending=False) - return torch_corr_matrix(rs.to(xs.dtype)) - - -# aliases -torch_pearsons_corr_matrix = torch_corr_matrix -torch_spearmans_corr_matrix = torch_rank_corr_matrix - - -# ========================================================================= # -# pytorch math helper functions # -# ========================================================================= # - - -def torch_tril_mean(mat: torch.Tensor, diagonal=-1): - """ - compute the mean of the lower triangular matrix. - """ - # checks - N, M = mat.shape - assert N == M - assert diagonal == -1 - # compute - n = (N*(N-1))/2 - mean = torch.tril(mat, diagonal=diagonal).sum() / n - # done - return mean - - -# ========================================================================= # -# pytorch mean functions # -# ========================================================================= # - - -_DimTypeHint = Optional[Union[int, List[int]]] - -_POS_INF = float('inf') -_NEG_INF = float('-inf') - -_GENERALIZED_MEAN_MAP = { - 'maximum': _POS_INF, - 'quadratic': 2, - 'arithmetic': 1, - 'geometric': 0, - 'harmonic': -1, - 'minimum': _NEG_INF, -} - - -def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1, keepdim: bool = False): - """ - Compute the generalised mean. - - p is the power - - harmonic mean ≤ geometric mean ≤ arithmetic mean - - If values have the same units: Use the arithmetic mean. - - If values have differing units: Use the geometric mean. - - If values are rates: Use the harmonic mean. 
- """ - if isinstance(p, str): - p = _GENERALIZED_MEAN_MAP[p] - # compute the specific extreme cases - if p == _POS_INF: - return torch.max(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.max(xs, keepdim=keepdim) - elif p == _NEG_INF: - return torch.min(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.min(xs, keepdim=keepdim) - # compute the number of elements being averaged - if dim is None: - dim = list(range(xs.ndim)) - n = torch.prod(torch.as_tensor(xs.shape)[dim]) - # warn if the type is wrong - if p != 1: - if xs.dtype != torch.float64: - warnings.warn(f'Input tensor to generalised mean might not have the required precision, type is {xs.dtype} not {torch.float64}.') - # compute the specific cases - if p == 0: - # geometric mean - # orig numerically unstable: torch.prod(xs, dim=dim) ** (1 / n) - return torch.exp((1 / n) * torch.sum(torch.log(xs), dim=dim, keepdim=keepdim)) - elif p == 1: - # arithmetic mean - return torch.mean(xs, dim=dim, keepdim=keepdim) - else: - # generalised mean - return ((1/n) * torch.sum(xs ** p, dim=dim, keepdim=keepdim)) ** (1/p) - - -def torch_mean_quadratic(xs, dim: _DimTypeHint = None, keepdim: bool = False): - return torch_mean_generalized(xs, dim=dim, p='quadratic', keepdim=keepdim) - - -def torch_mean_geometric(xs, dim: _DimTypeHint = None, keepdim: bool = False): - return torch_mean_generalized(xs, dim=dim, p='geometric', keepdim=keepdim) - - -def torch_mean_harmonic(xs, dim: _DimTypeHint = None, keepdim: bool = False): - return torch_mean_generalized(xs, dim=dim, p='harmonic', keepdim=keepdim) - - -# ========================================================================= # -# helper # -# ========================================================================= # - - -def torch_normalize(tensor: torch.Tensor, dims=None, dtype=None): - # get min & max values - if dims is not None: - m, M = tensor, tensor - for dim in dims: - m, M = m.min(dim=dim, keepdim=True).values, M.max(dim=dim, keepdim=True).values - else: - m, M = tensor.min(), tensor.max() - # scale tensor - return (tensor.to(dtype=dtype) - m) / (M - m) # automatically converts to float32 if needed - - -# ========================================================================= # -# polyfill - in later versions of pytorch # -# ========================================================================= # - - -def torch_nan_to_num(input, nan=0.0, posinf=None, neginf=None): - output = input.clone() - if nan is not None: - output[torch.isnan(input)] = nan - if posinf is not None: - output[input == np.inf] = posinf - if neginf is not None: - output[input == -np.inf] = neginf - return output - - -# ========================================================================= # -# PCA # -# ========================================================================= # - - -def torch_pca_eig(X, center=True, scale=False): - """ - perform PCA over X - - X is of size (num_points, vec_size) - - NOTE: unlike PCA_svd, the number of vectors/values returned is always: vec_size - """ - n, _ = X.shape - # center points along axes - if center: - X = X - X.mean(dim=0) - # compute covariance -- TODO: optimise this line - covariance = (1 / (n-1)) * torch.mm(X.T, X) - if scale: - scaling = torch.sqrt(1 / torch.diagonal(covariance)) - covariance = torch.mm(torch.diagflat(scaling), covariance) - # compute eigen values and eigen vectors - eigenvalues, eigenvectors = torch.eig(covariance, True) - # sort components by decreasing variance - components = eigenvectors.T - explained_variance = 
eigenvalues[:, 0] - idxs = torch.argsort(explained_variance, descending=True) - return components[idxs], explained_variance[idxs] - - -def torch_pca_svd(X, center=True): - """ - perform PCA over X - - X is of size (num_points, vec_size) - - NOTE: unlike PCA_eig, the number of vectors/values returned is: min(num_points, vec_size) - """ - n, _ = X.shape - # center points along axes - if center: - X = X - X.mean(dim=0) - # perform singular value decomposition - u, s, v = torch.svd(X) - # sort components by decreasing variance - # these are already sorted? - components = v.T - explained_variance = torch.mul(s, s) / (n-1) - return components, explained_variance - - -def torch_pca(X, center=True, mode='svd'): - if mode == 'svd': - return torch_pca_svd(X, center=center) - elif mode == 'eig': - return torch_pca_eig(X, center=center, scale=False) - else: - raise KeyError(f'invalid torch_pca mode: {repr(mode)}') - - -# ========================================================================= # -# DCT # -# ========================================================================= # - - -def _flatten_dim_to_end(input, dim): - # get shape - s = input.shape - n = s[dim] - # do operation - x = torch.moveaxis(input, dim, -1) - x = x.reshape(-1, n) - return x, s, n - - -def _unflatten_dim_to_end(input, dim, shape): - # get intermediate shape - s = list(shape) - s.append(s.pop(dim)) - # undo operation - x = input.reshape(*s) - x = torch.moveaxis(x, -1, dim) - return x - - -def torch_dct(x, dim=-1): - """ - Discrete Cosine Transform (DCT) Type II - """ - x, x_shape, n = _flatten_dim_to_end(x, dim=dim) - if n % 2 != 0: - raise ValueError(f'dct does not support odd sized dimension! trying to compute dct over dimension: {dim} of tensor with shape: {x_shape}') - - # concatenate even and odd offsets - v_evn = x[:, 0::2] - v_odd = x[:, 1::2].flip([1]) - v = torch.cat([v_evn, v_odd], dim=-1) - - # fast fourier transform - fft = torch.fft.fft(v) - - # compute real & imaginary forward weights - k = torch.arange(n, dtype=x.dtype, device=x.device) * (-np.pi / (2 * n)) - k = k[None, :] - wr = torch.cos(k) * 2 - wi = torch.sin(k) * 2 - - # compute dct - dct = torch.real(fft) * wr - torch.imag(fft) * wi - - # restore shape - return _unflatten_dim_to_end(dct, dim, x_shape) - - -def torch_idct(dct, dim=-1): - """ - Inverse Discrete Cosine Transform (Inverse DCT) Type III - """ - dct, dct_shape, n = _flatten_dim_to_end(dct, dim=dim) - if n % 2 != 0: - raise ValueError(f'idct does not support odd sized dimension! 
trying to compute idct over dimension: {dim} of tensor with shape: {dct_shape}') - - # compute real & imaginary backward weights - k = torch.arange(n, dtype=dct.dtype, device=dct.device) * (np.pi / (2 * n)) - k = k[None, :] - wr = torch.cos(k) / 2 - wi = torch.sin(k) / 2 - - dct_real = dct - dct_imag = torch.cat([0*dct_real[:, :1], -dct_real[:, 1:].flip([1])], dim=-1) - - fft_r = dct_real * wr - dct_imag * wi - fft_i = dct_real * wi + dct_imag * wr - # to complex number - fft = torch.view_as_complex(torch.stack([fft_r, fft_i], dim=-1)) - - # inverse fast fourier transform - v = torch.fft.ifft(fft) - v = torch.real(v) - - # undo even and odd offsets - x = torch.zeros_like(dct) - x[:, 0::2] = v[:, :(n+1)//2] # (N+1)//2 == N-(N//2) - x[:, 1::2] += v[:, (n+0)//2:].flip([1]) - - # restore shape - return _unflatten_dim_to_end(x, dim, dct_shape) - - -def torch_dct2(x, dim1=-1, dim2=-2): - d = torch_dct(x, dim=dim1) - d = torch_dct(d, dim=dim2) - return d - - -def torch_idct2(d, dim1=-1, dim2=-2): - x = torch_idct(d, dim=dim2) - x = torch_idct(x, dim=dim1) - return x - - -# ========================================================================= # -# Torch Dim Helper # -# ========================================================================= # - - -def torch_unsqueeze_l(input: torch.Tensor, n: int): - """ - Add n new axis to the left. - - eg. a tensor with shape (2, 3) passed to this function - with n=2 will input in an output shape of (1, 1, 2, 3) - """ - assert n >= 0, f'number of new axis cannot be less than zero, given: {repr(n)}' - return input[((None,)*n) + (...,)] - - -def torch_unsqueeze_r(input: torch.Tensor, n: int): - """ - Add n new axis to the right. - - eg. a tensor with shape (2, 3) passed to this function - with n=2 will input in an output shape of (2, 3, 1, 1) - """ - assert n >= 0, f'number of new axis cannot be less than zero, given: {repr(n)}' - return input[(...,) + ((None,)*n)] - - -# ========================================================================= # -# Kernels # -# ========================================================================= # - - -# TODO: replace with meshgrid based functions from experiment/exp/06_metric -# these are more intuitive and flexible - - -def get_kernel_size(sigma: TypeGenericTensor = 1.0, truncate: TypeGenericTensor = 4.0): - """ - This is how sklearn chooses kernel sizes. 
- - sigma is the standard deviation, and truncate is the number of deviations away to truncate - - our version broadcasts sigma and truncate together, returning the max kernel size needed over all values - """ - # compute radius - radius = generic_as_int32(truncate * sigma + 0.5) - # get maximum value - radius = int(generic_max(radius)) - # compute diameter - return 2 * radius + 1 - - -def torch_gaussian_kernel( - sigma: TypeGenericTorch = 1.0, truncate: TypeGenericTorch = 4.0, size: int = None, - dtype=torch.float32, device=None, -): - # broadcast tensors together -- data may reference single memory locations - sigma = torch.as_tensor(sigma, dtype=dtype, device=device) - truncate = torch.as_tensor(truncate, dtype=dtype, device=device) - sigma, truncate = torch.broadcast_tensors(sigma, truncate) - # compute default size - if size is None: - size: int = get_kernel_size(sigma=sigma, truncate=truncate) - # compute kernel - x = torch.arange(size, dtype=sigma.dtype, device=sigma.device) - (size - 1) / 2 - # pad tensors correctly - x = torch_unsqueeze_l(x, n=sigma.ndim) - s = torch_unsqueeze_r(sigma, n=1) - # compute - return torch.exp(-(x ** 2) / (2 * s ** 2)) / (np.sqrt(2 * np.pi) * s) - - -def torch_gaussian_kernel_2d( - sigma: TypeGenericTorch = 1.0, truncate: TypeGenericTorch = 4.0, size: int = None, - sigma_b: TypeGenericTorch = None, truncate_b: TypeGenericTorch = None, size_b: int = None, - dtype=torch.float32, device=None, -): - # set default values - if sigma_b is None: sigma_b = sigma - if truncate_b is None: truncate_b = truncate - if size_b is None: size_b = size - # compute kernel - kh = torch_gaussian_kernel(sigma=sigma, truncate=truncate, size=size, dtype=dtype, device=device) - kw = torch_gaussian_kernel(sigma=sigma_b, truncate=truncate_b, size=size_b, dtype=dtype, device=device) - return kh[..., :, None] * kw[..., None, :] - - -def torch_box_kernel(radius: TypeGenericTorch = 1, dtype=torch.float32, device=None): - radius = torch.abs(torch.as_tensor(radius, device=device)) - assert radius.dtype in {torch.int32, torch.int64}, f'box kernel radius must be of integer type: {radius.dtype}' - # box kernel values - radius_max = radius.max() - crange = torch.abs(torch.arange(radius_max * 2 + 1, dtype=dtype, device=device) - radius_max) - # pad everything - radius = radius[..., None] - crange = crange[None, ...] - # compute box kernel - kernel = (crange <= radius).to(dtype) / (radius * 2 + 1) - # done! - return kernel - - -def torch_box_kernel_2d( - radius: TypeGenericTorch = 1, - radius_b: TypeGenericTorch = None, - dtype=torch.float32, device=None -): - # set default values - if radius_b is None: radius_b = radius - # compute kernel - kh = torch_box_kernel(radius=radius, dtype=dtype, device=device) - kw = torch_box_kernel(radius=radius_b, dtype=dtype, device=device) - return kh[..., :, None] * kw[..., None, :] - - -# ========================================================================= # -# convolve # -# ========================================================================= # - - -def _check_conv2d_inputs(signal, kernel): - assert signal.ndim == 4, f'signal has {repr(signal.ndim)} dimensions, must have 4 dimensions instead: BxCxHxW' - assert kernel.ndim == 2 or kernel.ndim == 4, f'kernel has {repr(kernel.ndim)} dimensions, must have 2 or 4 dimensions instead: HxW or BxCxHxW' - # increase kernel size - if kernel.ndim == 2: - kernel = kernel[None, None, ...] 
- # check kernel is an odd size - kh, kw = kernel.shape[-2:] - assert kh % 2 != 0 and kw % 2 != 0, f'kernel dimension sizes must be odd: ({kh}, {kw})' - # check that broadcasting does not adjust the signal shape... TODO: relax this limitation? - assert torch.broadcast_shapes(signal.shape[:2], kernel.shape[:2]) == signal.shape[:2] - # done! - return signal, kernel - - -def torch_conv2d_channel_wise(signal, kernel): - """ - Apply the kernel to each channel separately! - """ - signal, kernel = _check_conv2d_inputs(signal, kernel) - # split channels into singel images - fsignal = signal.reshape(-1, 1, *signal.shape[2:]) - # convolve each channel image - out = torch.nn.functional.conv2d(fsignal, kernel, padding=(kernel.size(-2) // 2, kernel.size(-1) // 2)) - # reshape into original - return out.reshape(-1, signal.shape[1], *out.shape[2:]) - - -def torch_conv2d_channel_wise_fft(signal, kernel): - """ - The same as torch_conv2d_channel_wise, but apply the kernel using fft. - This is much more efficient for large filter sizes. - - Reference implementation is from: https://github.com/pyro-ppl/pyro/blob/ae55140acfdc6d4eade08b434195234e5ae8c261/pyro/ops/tensor_utils.py#L187 - """ - signal, kernel = _check_conv2d_inputs(signal, kernel) - # get last dimension sizes - sig_shape = np.array(signal.shape[-2:]) - ker_shape = np.array(kernel.shape[-2:]) - # compute padding - padded_shape = sig_shape + ker_shape - 1 - # Compute convolution using fft. - f_signal = torch.fft.rfft2(signal, s=tuple(padded_shape)) - f_kernel = torch.fft.rfft2(kernel, s=tuple(padded_shape)) - result = torch.fft.irfft2(f_signal * f_kernel, s=tuple(padded_shape)) - # crop final result - s = (padded_shape - sig_shape) // 2 - f = s + sig_shape - crop = result[..., s[0]:f[0], s[1]:f[1]] - # done... 
- return crop - - -# ========================================================================= # -# DEBUG # -# ========================================================================= # - - -def debug_transform_tensors(obj): - """ - recursively convert all tensors to their shapes for debugging - """ - if isinstance(obj, (torch.Tensor, np.ndarray)): - return obj.shape - elif isinstance(obj, dict): - return {debug_transform_tensors(k): debug_transform_tensors(v) for k, v in obj.items()} - elif isinstance(obj, list): - return list(debug_transform_tensors(v) for v in obj) - elif isinstance(obj, tuple): - return tuple(debug_transform_tensors(v) for v in obj) - elif isinstance(obj, set): - return {debug_transform_tensors(k) for k in obj} - else: - return obj - - -# ========================================================================= # -# END # -# ========================================================================= # - +from disent.nn.functional._conv2d import torch_conv2d_channel_wise +from disent.nn.functional._conv2d import torch_conv2d_channel_wise_fft + +from disent.nn.functional._conv2d_kernels import get_kernel_size +from disent.nn.functional._conv2d_kernels import torch_gaussian_kernel +from disent.nn.functional._conv2d_kernels import torch_gaussian_kernel_2d +from disent.nn.functional._conv2d_kernels import torch_box_kernel +from disent.nn.functional._conv2d_kernels import torch_box_kernel_2d + +from disent.nn.functional._correlation import torch_cov_matrix +from disent.nn.functional._correlation import torch_corr_matrix +from disent.nn.functional._correlation import torch_rank_corr_matrix +from disent.nn.functional._correlation import torch_pearsons_corr_matrix +from disent.nn.functional._correlation import torch_spearmans_corr_matrix + +from disent.nn.functional._dct import torch_dct +from disent.nn.functional._dct import torch_idct +from disent.nn.functional._dct import torch_dct2 +from disent.nn.functional._dct import torch_idct2 + +from disent.nn.functional._mean import torch_mean_generalized +from disent.nn.functional._mean import torch_mean_quadratic +from disent.nn.functional._mean import torch_mean_geometric +from disent.nn.functional._mean import torch_mean_harmonic + +from disent.nn.functional._other import torch_normalize +from disent.nn.functional._other import torch_nan_to_num +from disent.nn.functional._other import torch_unsqueeze_l +from disent.nn.functional._other import torch_unsqueeze_r + +from disent.nn.functional._pca import torch_pca_eig +from disent.nn.functional._pca import torch_pca_svd +from disent.nn.functional._pca import torch_pca + +# from disent.nn.functional._util_generic import TypeGenericTensor +# from disent.nn.functional._util_generic import TypeGenericTorch +# from disent.nn.functional._util_generic import TypeGenericNumpy +# from disent.nn.functional._util_generic import generic_as_int32 +# from disent.nn.functional._util_generic import generic_max +# from disent.nn.functional._util_generic import generic_min +# from disent.nn.functional._util_generic import generic_shape +# from disent.nn.functional._util_generic import generic_ndim diff --git a/disent/nn/functional/_conv2d.py b/disent/nn/functional/_conv2d.py new file mode 100644 index 00000000..8fbbb46d --- /dev/null +++ b/disent/nn/functional/_conv2d.py @@ -0,0 +1,89 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a 
copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +import torch + + +# ========================================================================= # +# convolve2d # +# ========================================================================= # + + +def _check_conv2d_inputs(signal, kernel): + assert signal.ndim == 4, f'signal has {repr(signal.ndim)} dimensions, must have 4 dimensions instead: BxCxHxW' + assert kernel.ndim == 2 or kernel.ndim == 4, f'kernel has {repr(kernel.ndim)} dimensions, must have 2 or 4 dimensions instead: HxW or BxCxHxW' + # increase kernel size + if kernel.ndim == 2: + kernel = kernel[None, None, ...] + # check kernel is an odd size + kh, kw = kernel.shape[-2:] + assert kh % 2 != 0 and kw % 2 != 0, f'kernel dimension sizes must be odd: ({kh}, {kw})' + # check that broadcasting does not adjust the signal shape... TODO: relax this limitation? + assert torch.broadcast_shapes(signal.shape[:2], kernel.shape[:2]) == signal.shape[:2] + # done! + return signal, kernel + + +def torch_conv2d_channel_wise(signal, kernel): + """ + Apply the kernel to each channel separately! + """ + signal, kernel = _check_conv2d_inputs(signal, kernel) + # split channels into single images + fsignal = signal.reshape(-1, 1, *signal.shape[2:]) + # convolve each channel image + out = torch.nn.functional.conv2d(fsignal, kernel, padding=(kernel.size(-2) // 2, kernel.size(-1) // 2)) + # reshape into original + return out.reshape(-1, signal.shape[1], *out.shape[2:]) + + +def torch_conv2d_channel_wise_fft(signal, kernel): + """ + The same as torch_conv2d_channel_wise, but apply the kernel using fft. + This is much more efficient for large filter sizes. + + Reference implementation is from: https://github.com/pyro-ppl/pyro/blob/ae55140acfdc6d4eade08b434195234e5ae8c261/pyro/ops/tensor_utils.py#L187 + """ + signal, kernel = _check_conv2d_inputs(signal, kernel) + # get last dimension sizes + sig_shape = np.array(signal.shape[-2:]) + ker_shape = np.array(kernel.shape[-2:]) + # compute padding + padded_shape = sig_shape + ker_shape - 1 + # Compute convolution using fft. + f_signal = torch.fft.rfft2(signal, s=tuple(padded_shape)) + f_kernel = torch.fft.rfft2(kernel, s=tuple(padded_shape)) + result = torch.fft.irfft2(f_signal * f_kernel, s=tuple(padded_shape)) + # crop final result + s = (padded_shape - sig_shape) // 2 + f = s + sig_shape + crop = result[..., s[0]:f[0], s[1]:f[1]] + # done... 
+ return crop + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_conv2d_kernels.py b/disent/nn/functional/_conv2d_kernels.py new file mode 100644 index 00000000..e69a1728 --- /dev/null +++ b/disent/nn/functional/_conv2d_kernels.py @@ -0,0 +1,124 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +import torch + +from disent.nn.functional._other import torch_unsqueeze_l +from disent.nn.functional._other import torch_unsqueeze_r +from disent.nn.functional._util_generic import generic_as_int32 +from disent.nn.functional._util_generic import generic_max +from disent.nn.functional._util_generic import TypeGenericTensor +from disent.nn.functional._util_generic import TypeGenericTorch + + +# ========================================================================= # +# Kernels # +# ========================================================================= # + + +# TODO: replace with meshgrid based functions from experiment/exp/06_metric +# these are more intuitive and flexible + + +def get_kernel_size(sigma: TypeGenericTensor = 1.0, truncate: TypeGenericTensor = 4.0): + """ + This is how sklearn chooses kernel sizes. 
+ - sigma is the standard deviation, and truncate is the number of deviations away to truncate + - our version broadcasts sigma and truncate together, returning the max kernel size needed over all values + """ + # compute radius + radius = generic_as_int32(truncate * sigma + 0.5) + # get maximum value + radius = int(generic_max(radius)) + # compute diameter + return 2 * radius + 1 + + +def torch_gaussian_kernel( + sigma: TypeGenericTorch = 1.0, truncate: TypeGenericTorch = 4.0, size: int = None, + dtype=torch.float32, device=None, +): + # broadcast tensors together -- data may reference single memory locations + sigma = torch.as_tensor(sigma, dtype=dtype, device=device) + truncate = torch.as_tensor(truncate, dtype=dtype, device=device) + sigma, truncate = torch.broadcast_tensors(sigma, truncate) + # compute default size + if size is None: + size: int = get_kernel_size(sigma=sigma, truncate=truncate) + # compute kernel + x = torch.arange(size, dtype=sigma.dtype, device=sigma.device) - (size - 1) / 2 + # pad tensors correctly + x = torch_unsqueeze_l(x, n=sigma.ndim) + s = torch_unsqueeze_r(sigma, n=1) + # compute + return torch.exp(-(x ** 2) / (2 * s ** 2)) / (np.sqrt(2 * np.pi) * s) + + +def torch_gaussian_kernel_2d( + sigma: TypeGenericTorch = 1.0, truncate: TypeGenericTorch = 4.0, size: int = None, + sigma_b: TypeGenericTorch = None, truncate_b: TypeGenericTorch = None, size_b: int = None, + dtype=torch.float32, device=None, +): + # set default values + if sigma_b is None: sigma_b = sigma + if truncate_b is None: truncate_b = truncate + if size_b is None: size_b = size + # compute kernel + kh = torch_gaussian_kernel(sigma=sigma, truncate=truncate, size=size, dtype=dtype, device=device) + kw = torch_gaussian_kernel(sigma=sigma_b, truncate=truncate_b, size=size_b, dtype=dtype, device=device) + return kh[..., :, None] * kw[..., None, :] + + +def torch_box_kernel(radius: TypeGenericTorch = 1, dtype=torch.float32, device=None): + radius = torch.abs(torch.as_tensor(radius, device=device)) + assert radius.dtype in {torch.int32, torch.int64}, f'box kernel radius must be of integer type: {radius.dtype}' + # box kernel values + radius_max = radius.max() + crange = torch.abs(torch.arange(radius_max * 2 + 1, dtype=dtype, device=device) - radius_max) + # pad everything + radius = radius[..., None] + crange = crange[None, ...] + # compute box kernel + kernel = (crange <= radius).to(dtype) / (radius * 2 + 1) + # done! 
+ return kernel + + +def torch_box_kernel_2d( + radius: TypeGenericTorch = 1, + radius_b: TypeGenericTorch = None, + dtype=torch.float32, device=None +): + # set default values + if radius_b is None: radius_b = radius + # compute kernel + kh = torch_box_kernel(radius=radius, dtype=dtype, device=device) + kw = torch_box_kernel(radius=radius_b, dtype=dtype, device=device) + return kh[..., :, None] * kw[..., None, :] + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_correlation.py b/disent/nn/functional/_correlation.py new file mode 100644 index 00000000..6618361f --- /dev/null +++ b/disent/nn/functional/_correlation.py @@ -0,0 +1,97 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import torch + + +# ========================================================================= # +# pytorch math correlation functions # +# ========================================================================= # + + +def torch_cov_matrix(xs: torch.Tensor): + """ + Calculate the covariance matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Covariance_matrix + - The input shape is: (N, X) + - The output shape is: (X, X) + + This should be the same as: + np.cov(xs, rowvar=False, ddof=0) + """ + # NOTE: + # torch.mm is strict matrix multiplication + # however if we multiply arrays with broadcasting: + # size(3, 1) * size(1, 2) -> size(3, 2) # broadcast, not matmul + # size(1, 3) * size(2, 1) -> size(2, 3) # broadcast, not matmul + # CHECK: + assert xs.ndim == 2 # (N, X) + Rxx = torch.mean(xs[:, :, None] * xs[:, None, :], dim=0) # (X, X) + ux = torch.mean(xs, dim=0) # (X,) + Kxx = Rxx - (ux[:, None] * ux[None, :]) # (X, X) + return Kxx + + +def torch_corr_matrix(xs: torch.Tensor): + """ + Calculate the pearson's correlation matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Pearson_correlation_coefficient + https://en.wikipedia.org/wiki/Covariance_matrix + - The input shape is: (N, X) + - The output shape is: (X, X) + + This should be the same as: + np.corrcoef(xs, rowvar=False, ddof=0) + """ + Kxx = torch_cov_matrix(xs) + diag_Kxx = torch.rsqrt(torch.diagonal(Kxx)) + corr = Kxx * (diag_Kxx[:, None] * diag_Kxx[None, :]) + return corr + + +def torch_rank_corr_matrix(xs: torch.Tensor): + """ + Calculate the spearman's rank correlation matrix of multiple samples (N) of random vectors of size (X) + https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient + - The input shape is: (N, X) + - The output shape is: (X, X) + + Pearson's correlation measures linear relationships + Spearman's correlation measures monotonic relationships (whether linear or not) + - defined in terms of the pearson's correlation matrix of the rank variables + + TODO: check, be careful of repeated values, this might not give the correct input? + """ + rs = torch.argsort(xs, dim=0, descending=False) + return torch_corr_matrix(rs.to(xs.dtype)) + + +# aliases +torch_pearsons_corr_matrix = torch_corr_matrix +torch_spearmans_corr_matrix = torch_rank_corr_matrix + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_dct.py b/disent/nn/functional/_dct.py new file mode 100644 index 00000000..2c935e79 --- /dev/null +++ b/disent/nn/functional/_dct.py @@ -0,0 +1,132 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +import torch + + +# ========================================================================= # +# Discrete Cosine Transform # +# ========================================================================= # + + +def _flatten_dim_to_end(input, dim): + # get shape + s = input.shape + n = s[dim] + # do operation + x = torch.moveaxis(input, dim, -1) + x = x.reshape(-1, n) + return x, s, n + + +def _unflatten_dim_to_end(input, dim, shape): + # get intermediate shape + s = list(shape) + s.append(s.pop(dim)) + # undo operation + x = input.reshape(*s) + x = torch.moveaxis(x, -1, dim) + return x + + +def torch_dct(x, dim=-1): + """ + Discrete Cosine Transform (DCT) Type II + """ + x, x_shape, n = _flatten_dim_to_end(x, dim=dim) + if n % 2 != 0: + raise ValueError(f'dct does not support odd sized dimension! trying to compute dct over dimension: {dim} of tensor with shape: {x_shape}') + + # concatenate even and odd offsets + v_evn = x[:, 0::2] + v_odd = x[:, 1::2].flip([1]) + v = torch.cat([v_evn, v_odd], dim=-1) + + # fast fourier transform + fft = torch.fft.fft(v) + + # compute real & imaginary forward weights + k = torch.arange(n, dtype=x.dtype, device=x.device) * (-np.pi / (2 * n)) + k = k[None, :] + wr = torch.cos(k) * 2 + wi = torch.sin(k) * 2 + + # compute dct + dct = torch.real(fft) * wr - torch.imag(fft) * wi + + # restore shape + return _unflatten_dim_to_end(dct, dim, x_shape) + + +def torch_idct(dct, dim=-1): + """ + Inverse Discrete Cosine Transform (Inverse DCT) Type III + """ + dct, dct_shape, n = _flatten_dim_to_end(dct, dim=dim) + if n % 2 != 0: + raise ValueError(f'idct does not support odd sized dimension! 
trying to compute idct over dimension: {dim} of tensor with shape: {dct_shape}') + + # compute real & imaginary backward weights + k = torch.arange(n, dtype=dct.dtype, device=dct.device) * (np.pi / (2 * n)) + k = k[None, :] + wr = torch.cos(k) / 2 + wi = torch.sin(k) / 2 + + dct_real = dct + dct_imag = torch.cat([0*dct_real[:, :1], -dct_real[:, 1:].flip([1])], dim=-1) + + fft_r = dct_real * wr - dct_imag * wi + fft_i = dct_real * wi + dct_imag * wr + # to complex number + fft = torch.view_as_complex(torch.stack([fft_r, fft_i], dim=-1)) + + # inverse fast fourier transform + v = torch.fft.ifft(fft) + v = torch.real(v) + + # undo even and odd offsets + x = torch.zeros_like(dct) + x[:, 0::2] = v[:, :(n+1)//2] # (N+1)//2 == N-(N//2) + x[:, 1::2] += v[:, (n+0)//2:].flip([1]) + + # restore shape + return _unflatten_dim_to_end(x, dim, dct_shape) + + +def torch_dct2(x, dim1=-1, dim2=-2): + d = torch_dct(x, dim=dim1) + d = torch_dct(d, dim=dim2) + return d + + +def torch_idct2(d, dim1=-1, dim2=-2): + x = torch_idct(d, dim=dim2) + x = torch_idct(x, dim=dim1) + return x + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_mean.py b/disent/nn/functional/_mean.py new file mode 100644 index 00000000..00bdf0a0 --- /dev/null +++ b/disent/nn/functional/_mean.py @@ -0,0 +1,110 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import warnings +from typing import List +from typing import Optional +from typing import Union + +import torch + + +# ========================================================================= # +# Types # +# ========================================================================= # + + +_DimTypeHint = Optional[Union[int, List[int]]] + +_POS_INF = float('inf') +_NEG_INF = float('-inf') + +_GENERALIZED_MEAN_MAP = { + 'maximum': _POS_INF, + 'quadratic': 2, + 'arithmetic': 1, + 'geometric': 0, + 'harmonic': -1, + 'minimum': _NEG_INF, +} + + +# ========================================================================= # +# Generalized Mean Functions # +# ========================================================================= # + + +def torch_mean_generalized(xs: torch.Tensor, dim: _DimTypeHint = None, p: Union[int, str] = 1, keepdim: bool = False): + """ + Compute the generalised mean. + - p is the power + + harmonic mean ≤ geometric mean ≤ arithmetic mean + - If values have the same units: Use the arithmetic mean. + - If values have differing units: Use the geometric mean. + - If values are rates: Use the harmonic mean. + """ + if isinstance(p, str): + p = _GENERALIZED_MEAN_MAP[p] + # compute the specific extreme cases + if p == _POS_INF: + return torch.max(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.max(xs, keepdim=keepdim) + elif p == _NEG_INF: + return torch.min(xs, dim=dim, keepdim=keepdim).values if (dim is not None) else torch.min(xs, keepdim=keepdim) + # compute the number of elements being averaged + if dim is None: + dim = list(range(xs.ndim)) + n = torch.prod(torch.as_tensor(xs.shape)[dim]) + # warn if the type is wrong + if p != 1: + if xs.dtype != torch.float64: + warnings.warn(f'Input tensor to generalised mean might not have the required precision, type is {xs.dtype} not {torch.float64}.') + # compute the specific cases + if p == 0: + # geometric mean + # orig numerically unstable: torch.prod(xs, dim=dim) ** (1 / n) + return torch.exp((1 / n) * torch.sum(torch.log(xs), dim=dim, keepdim=keepdim)) + elif p == 1: + # arithmetic mean + return torch.mean(xs, dim=dim, keepdim=keepdim) + else: + # generalised mean + return ((1/n) * torch.sum(xs ** p, dim=dim, keepdim=keepdim)) ** (1/p) + + +def torch_mean_quadratic(xs, dim: _DimTypeHint = None, keepdim: bool = False): + return torch_mean_generalized(xs, dim=dim, p='quadratic', keepdim=keepdim) + + +def torch_mean_geometric(xs, dim: _DimTypeHint = None, keepdim: bool = False): + return torch_mean_generalized(xs, dim=dim, p='geometric', keepdim=keepdim) + + +def torch_mean_harmonic(xs, dim: _DimTypeHint = None, keepdim: bool = False): + return torch_mean_generalized(xs, dim=dim, p='harmonic', keepdim=keepdim) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_other.py b/disent/nn/functional/_other.py new file mode 100644 index 00000000..3b95d837 --- /dev/null +++ b/disent/nn/functional/_other.py @@ -0,0 +1,91 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without 
limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +import torch + + +# ========================================================================= # +# helper # +# ========================================================================= # + + +def torch_normalize(tensor: torch.Tensor, dims=None, dtype=None): + # get min & max values + if dims is not None: + m, M = tensor, tensor + for dim in dims: + m, M = m.min(dim=dim, keepdim=True).values, M.max(dim=dim, keepdim=True).values + else: + m, M = tensor.min(), tensor.max() + # scale tensor + return (tensor.to(dtype=dtype) - m) / (M - m) # automatically converts to float32 if needed + + +# ========================================================================= # +# polyfill - in later versions of pytorch # +# ========================================================================= # + + +def torch_nan_to_num(input, nan=0.0, posinf=None, neginf=None): + output = input.clone() + if nan is not None: + output[torch.isnan(input)] = nan + if posinf is not None: + output[input == np.inf] = posinf + if neginf is not None: + output[input == -np.inf] = neginf + return output + + +# ========================================================================= # +# Torch Dim Helper # +# ========================================================================= # + + +def torch_unsqueeze_l(input: torch.Tensor, n: int): + """ + Add n new axes to the left. + + eg. a tensor with shape (2, 3) passed to this function + with n=2 will result in an output shape of (1, 1, 2, 3) + """ + assert n >= 0, f'number of new axis cannot be less than zero, given: {repr(n)}' + return input[((None,)*n) + (...,)] + + +def torch_unsqueeze_r(input: torch.Tensor, n: int): + """ + Add n new axes to the right. + + eg. 
a tensor with shape (2, 3) passed to this function + with n=2 will result in an output shape of (2, 3, 1, 1) + """ + assert n >= 0, f'number of new axis cannot be less than zero, given: {repr(n)}' + return input[(...,) + ((None,)*n)] + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_pca.py b/disent/nn/functional/_pca.py new file mode 100644 index 00000000..33212f41 --- /dev/null +++ b/disent/nn/functional/_pca.py @@ -0,0 +1,89 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import torch + + +# ========================================================================= # +# PCA # +# ========================================================================= # + + +def torch_pca_eig(X, center=True, scale=False): + """ + perform PCA over X + - X is of size (num_points, vec_size) + + NOTE: unlike PCA_svd, the number of vectors/values returned is always: vec_size + """ + n, _ = X.shape + # center points along axes + if center: + X = X - X.mean(dim=0) + # compute covariance -- TODO: optimise this line + covariance = (1 / (n-1)) * torch.mm(X.T, X) + if scale: + scaling = torch.sqrt(1 / torch.diagonal(covariance)) + covariance = torch.mm(torch.diagflat(scaling), covariance) + # compute eigen values and eigen vectors + eigenvalues, eigenvectors = torch.eig(covariance, True) + # sort components by decreasing variance + components = eigenvectors.T + explained_variance = eigenvalues[:, 0] + idxs = torch.argsort(explained_variance, descending=True) + return components[idxs], explained_variance[idxs] + + +def torch_pca_svd(X, center=True): + """ + perform PCA over X + - X is of size (num_points, vec_size) + + NOTE: unlike PCA_eig, the number of vectors/values returned is: min(num_points, vec_size) + """ + n, _ = X.shape + # center points along axes + if center: + X = X - X.mean(dim=0) + # perform singular value decomposition + u, s, v = torch.svd(X) + # sort components by decreasing variance + # these are already sorted? 
+ components = v.T + explained_variance = torch.mul(s, s) / (n-1) + return components, explained_variance + + +def torch_pca(X, center=True, mode='svd'): + if mode == 'svd': + return torch_pca_svd(X, center=center) + elif mode == 'eig': + return torch_pca_eig(X, center=center, scale=False) + else: + raise KeyError(f'invalid torch_pca mode: {repr(mode)}') + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/nn/functional/_generic_tensors.py b/disent/nn/functional/_util_generic.py similarity index 98% rename from disent/nn/functional/_generic_tensors.py rename to disent/nn/functional/_util_generic.py index be0dc299..8a390d3f 100644 --- a/disent/nn/functional/_generic_tensors.py +++ b/disent/nn/functional/_util_generic.py @@ -21,16 +21,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + from typing import Union -import logging import numpy as np import torch -log = logging.getLogger(__name__) - - # ========================================================================= # # Generic Helper Functions # # - torch, numpy, scalars # diff --git a/disent/nn/loss/softsort.py b/disent/nn/loss/softsort.py index 266ac875..2f50d16e 100644 --- a/disent/nn/loss/softsort.py +++ b/disent/nn/loss/softsort.py @@ -21,6 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + from functools import lru_cache from typing import Tuple diff --git a/disent/nn/transform/_transforms.py b/disent/nn/transform/_transforms.py deleted file mode 100644 index 1eb7c255..00000000 --- a/disent/nn/transform/_transforms.py +++ /dev/null @@ -1,100 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import torch -import disent.nn.transform.functional as F_d - - -# ========================================================================= # -# Transforms # -# ========================================================================= # - - -class Noop(object): - """ - Transform that does absolutely nothing! 
- See: disent.transform.functional.noop - """ - - def __call__(self, obs): - return obs - - def __repr__(self): - return f'{self.__class__.__name__}()' - - -class CheckTensor(object): - """ - Check that the data is a tensor, the right dtype, and in the required range. - See: disent.transform.functional.check_tensor - """ - - def __init__(self, low=0., high=1., dtype=torch.float32): - self._low = low - self._high = high - self._dtype = dtype - - def __call__(self, obs): - return F_d.check_tensor(obs, low=self._low, high=self._high, dtype=self._dtype) - - def __repr__(self): - return f'{self.__class__.__name__}(low={repr(self._low)}, high={repr(self._high)}, dtype={repr(self._dtype)})' - - -class ToStandardisedTensor(object): - """ - Standardise image data after loading, by converting to a tensor - in range [0, 1], and resizing to a square if specified. - See: disent.transform.functional.to_standardised_tensor - """ - - def __init__(self, size: F_d.SizeType = None, cast_f32: bool = False, check: bool = True, check_range: bool = True): - self._size = size - self._cast_f32 = cast_f32 # cast after resizing before checks -- disabled by default to so dtype errors can be seen - self._check = check - self._check_range = check_range # if check is `False` then `check_range` can never be `True` - - def __call__(self, obs) -> torch.Tensor: - return F_d.to_standardised_tensor(obs, size=self._size, cast_f32=self._cast_f32, check=self._check, check_range=self._check_range) - - def __repr__(self): - return f'{self.__class__.__name__}(size={repr(self._size)})' - - -class ToUint8Tensor(object): - - def __init__(self, size: F_d.SizeType = None, channel_to_front: bool = True): - self._size = size - self._channel_to_front = channel_to_front - - def __call__(self, obs) -> torch.Tensor: - return F_d.to_uint_tensor(obs, size=self._size, channel_to_front=self._channel_to_front) - - def __repr__(self): - return f'{self.__class__.__name__}(size={repr(self._size)})' - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/nn/transform/functional.py b/disent/nn/transform/functional.py deleted file mode 100644 index 297afb03..00000000 --- a/disent/nn/transform/functional.py +++ /dev/null @@ -1,128 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from typing import Optional -from typing import Tuple -from typing import Union - -import numpy as np -from PIL.Image import Image -import torch -import torchvision.transforms.functional as F_tv - - -# ========================================================================= # -# Functional Transforms # -# ========================================================================= # - - -def noop(obs): - """ - Transform that does absolutely nothing! - """ - return obs - - -def check_tensor(obs, low: Optional[float] = 0., high: Optional[float] = 1., dtype=torch.float32): - """ - Check that the input is a tensor, its datatype matches, and - that it is in the required range. - """ - # check is a tensor - assert torch.is_tensor(obs), 'observation is not a tensor' - # check type - if dtype is not None: - assert obs.dtype == dtype, f'tensor type {obs.dtype} is not required type {dtype}' - # check range | TODO: are assertion strings inefficient? - if low is not None: - assert low <= obs.min(), f'minimum value of tensor {obs.min()} is less than allowed minimum value: {low}' - if high is not None: - assert obs.max() <= high, f'maximum value of tensor {obs.max()} is greater than allowed maximum value: {high}' - # DONE! - return obs - - -Obs = Union[np.ndarray, Image] -SizeType = Union[int, Tuple[int, int]] - - -def to_uint_tensor( - obs: Obs, - size: Optional[SizeType] = None, - channel_to_front: bool = True -) -> torch.Tensor: - # resize image - if size is not None: - if not isinstance(obs, Image): - obs = F_tv.to_pil_image(obs) - obs = F_tv.resize(obs, size=size) - # to numpy - if not isinstance(obs, np.ndarray): - obs = np.array(obs) - # to tensor - obs = torch.from_numpy(obs) - # move axis - if channel_to_front: - obs = torch.moveaxis(obs, -1, -3) - # checks - assert obs.dtype == torch.uint8 - # done! - return obs - - -def to_standardised_tensor( - obs: Obs, - size: Optional[SizeType] = None, - cast_f32: bool = False, - check: bool = True, - check_range: bool = True, -) -> torch.Tensor: - """ - Basic transform that should be applied to - any dataset before augmentation. - - 1. resize if size is specified - 2. 
convert to tensor in range [0, 1] - """ - # resize image - if size is not None: - if not isinstance(obs, Image): - obs = F_tv.to_pil_image(obs) - obs = F_tv.resize(obs, size=size) - # transform to tensor - obs = F_tv.to_tensor(obs) - # cast if needed - if cast_f32: - obs = obs.to(torch.float32) - # check that tensor is valid - if check: - if check_range: - obs = check_tensor(obs, low=0, high=1, dtype=torch.float32) - else: - obs = check_tensor(obs, low=None, high=None, dtype=torch.float32) - return obs - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/nn/weights.py b/disent/nn/weights.py index 6be5d9c4..1a09b852 100644 --- a/disent/nn/weights.py +++ b/disent/nn/weights.py @@ -23,6 +23,7 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Optional from torch import nn from disent.util.strings import colors as c @@ -36,7 +37,7 @@ # ========================================================================= # -def init_model_weights(model: nn.Module, mode='xavier_normal', log_level=logging.INFO): +def init_model_weights(model: nn.Module, mode: Optional[str] = 'xavier_normal', log_level=logging.INFO) -> nn.Module: count = 0 # get default mode @@ -56,7 +57,7 @@ def init_normal(m): elif mode == 'default': pass else: - raise KeyError(f'Unknown init mode: {repr(mode)}') + raise KeyError(f'Unknown init mode: {repr(mode)}, valid modes are: {["xavier_normal", "default"]}') # print messages if init: diff --git a/disent/registry/__init__.py b/disent/registry/__init__.py index 7136843f..82814540 100644 --- a/disent/registry/__init__.py +++ b/disent/registry/__init__.py @@ -30,226 +30,217 @@ *NB* All modules and classes are lazily imported! -# TODO: this needs to be more flexible - - support custom registration - - support aliases - - support validation of objects - - add factory methods +You can register your own modules and classes using the provided decorator: +eg. `DATASET.register(...options...)(your_function_or_class)` """ +from disent.registry._registry import Registry as _Registry +from disent.registry._registry import LazyImport as _LazyImport -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! -# TODO: this file needs to be cleaned up!!! - -# ========================================================================= # -# Fake Imports # -# ========================================================================= # - - -from disent.registry import _registry_util as _R - - -# this is to trick the PyCharm type hinting system, used in -# conjunction with the `if False: import ...` statements which -# should never actually be run! 
-_disent = _R.PathBuilder('disent') -_torch = _R.PathBuilder('torch') -_torch_optimizer = _R.PathBuilder('torch_optimizer') - - -if False: - import disent as _disent - import torch as _torch - import torch_optimizer as _torch_optimizer - # force pycharm to load hints - import disent.dataset.data as _ - import disent.dataset.sampling as _ - import disent.frameworks.ae as _ - import disent.frameworks.vae as _ - import disent.frameworks.helper.reconstructions as _ - import disent.frameworks.helper.latent_distributions as _ - import disent.metrics as _ - import disent.schedule as _ - import disent.model.ae as _ - import torch.optim as _ - - -# ========================================================================= # -# Registries # -# TODO: registries should support multiple aliases -# ========================================================================= # - - -# changes here should also update `disent/dataset/data/__init__.py` -class DATASET(metaclass=_R.LazyImportMeta()): - # [groundtruth -- impl] - Cars3d = _disent.dataset.data._groundtruth__cars3d.Cars3dData - DSprites = _disent.dataset.data._groundtruth__dsprites.DSpritesData - Mpi3d = _disent.dataset.data._groundtruth__mpi3d.Mpi3dData - SmallNorb = _disent.dataset.data._groundtruth__norb.SmallNorbData - Shapes3d = _disent.dataset.data._groundtruth__shapes3d.Shapes3dData - XYObject = _disent.dataset.data._groundtruth__xyobject.XYObjectData - - -# changes here should also update `disent/dataset/sampling/__init__.py` -class SAMPLER(metaclass=_R.LazyImportMeta()): - # [ground truth samplers] - GT_Dist = _disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler - GT_Pair = _disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler - GT_PairOrig = _disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler - GT_Single = _disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler - GT_Triple = _disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler - # [any dataset samplers] - Single = _disent.dataset.sampling._single.SingleSampler - Random = _disent.dataset.sampling._random__any.RandomSampler - # [episode samplers] - RandomEpisode = _disent.dataset.sampling._random__episodes.RandomEpisodeSampler - - -# changes here should also update `disent/frameworks/ae/__init__.py` & `disent/frameworks/vae/__init__.py` -class FRAMEWORK(metaclass=_R.LazyImportMeta()): - # [AE] - TripletAe = _disent.frameworks.ae._supervised__tae.TripletAe - Ae = _disent.frameworks.ae._unsupervised__ae.Ae - # [VAE] - TripletVae = _disent.frameworks.vae._supervised__tvae.TripletVae - BetaTcVae = _disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae - BetaVae = _disent.frameworks.vae._unsupervised__betavae.BetaVae - DfcVae = _disent.frameworks.vae._unsupervised__dfcvae.DfcVae - DipVae = _disent.frameworks.vae._unsupervised__dipvae.DipVae - InfoVae = _disent.frameworks.vae._unsupervised__infovae.InfoVae - Vae = _disent.frameworks.vae._unsupervised__vae.Vae - AdaVae = _disent.frameworks.vae._weaklysupervised__adavae.AdaVae - - -# changes here should also update `disent/frameworks/helper/reconstructions.py` -class RECON_LOSS(metaclass=_R.LazyImportMeta(to_lowercase=True)): - # [STANDARD LOSSES] - Mse = _disent.frameworks.helper.reconstructions.ReconLossHandlerMse # from the normal distribution - real values in the range [0, 1] - Mae = _disent.frameworks.helper.reconstructions.ReconLossHandlerMae # mean absolute error - # [STANDARD DISTRIBUTIONS] - Bce = 
_disent.frameworks.helper.reconstructions.ReconLossHandlerBce # from the bernoulli distribution - binary values in the set {0, 1} - Bernoulli = _disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli # reduces to bce - binary values in the set {0, 1} - ContinuousBernoulli = _disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli # bernoulli with a computed offset to handle values in the range [0, 1] - Normal = _disent.frameworks.helper.reconstructions.ReconLossHandlerNormal # handle all real values - - -# changes here should also update `disent/frameworks/helper/latent_distributions.py` -class LATENT_DIST(metaclass=_R.LazyImportMeta()): - Normal = _disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal - Laplace = _disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace - - -# non-disent classes -class OPTIMIZER(metaclass=_R.LazyImportMeta()): - # [torch] - Adadelta = _torch.optim.adadelta.Adadelta - Adagrad = _torch.optim.adagrad.Adagrad - Adam = _torch.optim.adam.Adam - Adamax = _torch.optim.adamax.Adamax - AdamW = _torch.optim.adamw.AdamW - ASGD = _torch.optim.asgd.ASGD - LBFGS = _torch.optim.lbfgs.LBFGS - RMSprop = _torch.optim.rmsprop.RMSprop - Rprop = _torch.optim.rprop.Rprop - SGD = _torch.optim.sgd.SGD - SparseAdam= _torch.optim.sparse_adam.SparseAdam - # [torch_optimizer] - non-optimizers: Lookahead - A2GradExp = _torch_optimizer.A2GradExp - A2GradInc = _torch_optimizer.A2GradInc - A2GradUni = _torch_optimizer.A2GradUni - AccSGD = _torch_optimizer.AccSGD - AdaBelief = _torch_optimizer.AdaBelief - AdaBound = _torch_optimizer.AdaBound - AdaMod = _torch_optimizer.AdaMod - Adafactor = _torch_optimizer.Adafactor - Adahessian= _torch_optimizer.Adahessian - AdamP = _torch_optimizer.AdamP - AggMo = _torch_optimizer.AggMo - Apollo = _torch_optimizer.Apollo - DiffGrad = _torch_optimizer.DiffGrad - Lamb = _torch_optimizer.Lamb - NovoGrad = _torch_optimizer.NovoGrad - PID = _torch_optimizer.PID - QHAdam = _torch_optimizer.QHAdam - QHM = _torch_optimizer.QHM - RAdam = _torch_optimizer.RAdam - Ranger = _torch_optimizer.Ranger - RangerQH = _torch_optimizer.RangerQH - RangerVA = _torch_optimizer.RangerVA - SGDP = _torch_optimizer.SGDP - SGDW = _torch_optimizer.SGDW - SWATS = _torch_optimizer.SWATS - Shampoo = _torch_optimizer.Shampoo - Yogi = _torch_optimizer.Yogi - - -# changes here should also update `disent/metrics/__init__.py` -class METRIC(metaclass=_R.LazyImportMeta()): - dci = _disent.metrics._dci.metric_dci - factor_vae = _disent.metrics._factor_vae.metric_factor_vae - mig = _disent.metrics._mig.metric_mig - sap = _disent.metrics._sap.metric_sap - unsupervised = _disent.metrics._unsupervised.metric_unsupervised - - -# changes here should also update `disent/schedule/__init__.py` -class SCHEDULE(metaclass=_R.LazyImportMeta()): - Clip = _disent.schedule._schedule.ClipSchedule - CosineWave = _disent.schedule._schedule.CosineWaveSchedule - Cyclic = _disent.schedule._schedule.CyclicSchedule - Linear = _disent.schedule._schedule.LinearSchedule - Noop = _disent.schedule._schedule.NoopSchedule - - -# changes here should also update `disent/model/ae/__init__.py` -class MODEL(metaclass=_R.LazyImportMeta()): - # [DECODER] - EncoderConv64 = _disent.model.ae._vae_conv64.EncoderConv64 - EncoderConv64Norm = _disent.model.ae._norm_conv64.EncoderConv64Norm - EncoderFC = _disent.model.ae._vae_fc.EncoderFC - EncoderTest = _disent.model.ae._test.EncoderTest - # [ENCODER] - DecoderConv64 = _disent.model.ae._vae_conv64.DecoderConv64 - 
DecoderConv64Norm = _disent.model.ae._norm_conv64.DecoderConv64Norm - DecoderFC = _disent.model.ae._vae_fc.DecoderFC - DecoderTest = _disent.model.ae._test.DecoderTest +# ========================================================================= # +# DATASETS - should be synchronized with: `disent/dataset/data/__init__.py` # +# ========================================================================= # + + +# TODO: this is not yet used in disent.data or disent.frameworks +DATASETS = _Registry('DATASETS') +# groundtruth -- impl +DATASETS['cars3d'] = _LazyImport('disent.dataset.data._groundtruth__cars3d') +DATASETS['dsprites'] = _LazyImport('disent.dataset.data._groundtruth__dsprites') +DATASETS['mpi3d'] = _LazyImport('disent.dataset.data._groundtruth__mpi3d') +DATASETS['smallnorb'] = _LazyImport('disent.dataset.data._groundtruth__norb') +DATASETS['shapes3d'] = _LazyImport('disent.dataset.data._groundtruth__shapes3d') +# groundtruth -- impl synthetic +DATASETS['xyobject'] = _LazyImport('disent.dataset.data._groundtruth__xyobject') # ========================================================================= # -# Registry of all Registries # +# SAMPLERS - should be synchronized with: # +# `disent/dataset/sampling/__init__.py` # +# ========================================================================= # + + +# TODO: this is not yet used in disent.data or disent.frameworks +# changes here should also update +SAMPLERS = _Registry('SAMPLERS') +# [ground truth samplers] +SAMPLERS['gt_dist'] = _LazyImport('disent.dataset.sampling._groundtruth__dist.GroundTruthDistSampler') +SAMPLERS['gt_pair'] = _LazyImport('disent.dataset.sampling._groundtruth__pair.GroundTruthPairSampler') +SAMPLERS['gt_pair_orig'] = _LazyImport('disent.dataset.sampling._groundtruth__pair_orig.GroundTruthPairOrigSampler') +SAMPLERS['gt_single'] = _LazyImport('disent.dataset.sampling._groundtruth__single.GroundTruthSingleSampler') +SAMPLERS['gt_triple'] = _LazyImport('disent.dataset.sampling._groundtruth__triplet.GroundTruthTripleSampler') +# [any dataset samplers] +SAMPLERS['single'] = _LazyImport('disent.dataset.sampling._single.SingleSampler') +SAMPLERS['random'] = _LazyImport('disent.dataset.sampling._random__any.RandomSampler') +# [episode samplers] +SAMPLERS['random_episode'] = _LazyImport('disent.dataset.sampling._random__episodes.RandomEpisodeSampler') + + +# ========================================================================= # +# FRAMEWORKS - should be synchronized with: # +# `disent/frameworks/ae/__init__.py` # +# `disent/frameworks/ae/experimental/__init__.py` # +# `disent/frameworks/vae/__init__.py` # +# `disent/frameworks/vae/experimental/__init__.py` # +# ========================================================================= # + + +# TODO: this is not yet used in disent.frameworks +FRAMEWORKS = _Registry('FRAMEWORKS') +# [AE] +FRAMEWORKS['tae'] = _LazyImport('disent.frameworks.ae._supervised__tae.TripletAe') +FRAMEWORKS['ae'] = _LazyImport('disent.frameworks.ae._unsupervised__ae.Ae') +# [VAE] +FRAMEWORKS['tvae'] = _LazyImport('disent.frameworks.vae._supervised__tvae.TripletVae') +FRAMEWORKS['betatc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betatcvae.BetaTcVae') +FRAMEWORKS['beta_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__betavae.BetaVae') +FRAMEWORKS['dfc_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dfcvae.DfcVae') +FRAMEWORKS['dip_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__dipvae.DipVae') +FRAMEWORKS['info_vae'] = 
_LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae') +FRAMEWORKS['vae'] = _LazyImport('disent.frameworks.vae._unsupervised__vae.Vae') +FRAMEWORKS['ada_vae'] = _LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae') + + +# ========================================================================= # +# RECON_LOSSES - should be synchronized with: # +# `disent/frameworks/helper/reconstructions.py` # +# ========================================================================= # + + +RECON_LOSSES = _Registry('RECON_LOSSES') +# [STANDARD LOSSES] +RECON_LOSSES['mse'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMse') # from the normal distribution - real values in the range [0, 1] +RECON_LOSSES['mae'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerMae') # mean absolute error +# [STANDARD DISTRIBUTIONS] +RECON_LOSSES['bce'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBce') # from the bernoulli distribution - binary values in the set {0, 1} +RECON_LOSSES['bernoulli'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerBernoulli') # reduces to bce - binary values in the set {0, 1} +RECON_LOSSES['c_bernoulli'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerContinuousBernoulli') # bernoulli with a computed offset to handle values in the range [0, 1] +RECON_LOSSES['normal'] = _LazyImport('disent.frameworks.helper.reconstructions.ReconLossHandlerNormal') # handle all real values + + +# ========================================================================= # +# LATENT_DISTS - should be synchronized with: # +# `disent/frameworks/helper/latent_distributions.py` # +# ========================================================================= # + + +# TODO: this is not yet used in disent.frameworks or disent.frameworks.helper.latent_distributions +LATENT_DISTS = _Registry('LATENT_DISTS') +LATENT_DISTS['normal'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerNormal') +LATENT_DISTS['laplace'] = _LazyImport('disent.frameworks.helper.latent_distributions.LatentDistsHandlerLaplace') + + +# ========================================================================= # +# OPTIMIZER # +# ========================================================================= # + + +# default learning rate for each optimizer +_LR = 1e-3 + + +OPTIMIZERS = _Registry('OPTIMIZERS') +# [torch] +OPTIMIZERS['adadelta'] = _LazyImport(lr=_LR, import_path='torch.optim.adadelta.Adadelta') +OPTIMIZERS['adagrad'] = _LazyImport(lr=_LR, import_path='torch.optim.adagrad.Adagrad') +OPTIMIZERS['adam'] = _LazyImport(lr=_LR, import_path='torch.optim.adam.Adam') +OPTIMIZERS['adamax'] = _LazyImport(lr=_LR, import_path='torch.optim.adamax.Adamax') +OPTIMIZERS['adam_w'] = _LazyImport(lr=_LR, import_path='torch.optim.adamw.AdamW') +OPTIMIZERS['asgd'] = _LazyImport(lr=_LR, import_path='torch.optim.asgd.ASGD') +OPTIMIZERS['lbfgs'] = _LazyImport(lr=_LR, import_path='torch.optim.lbfgs.LBFGS') +OPTIMIZERS['rmsprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rmsprop.RMSprop') +OPTIMIZERS['rprop'] = _LazyImport(lr=_LR, import_path='torch.optim.rprop.Rprop') +OPTIMIZERS['sgd'] = _LazyImport(lr=_LR, import_path='torch.optim.sgd.SGD') +OPTIMIZERS['sparse_adam'] = _LazyImport(lr=_LR, import_path='torch.optim.sparse_adam.SparseAdam') +# [torch_optimizer] +OPTIMIZERS['acc_sgd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AccSGD') +OPTIMIZERS['ada_bound'] = _LazyImport(lr=_LR, 
import_path='torch_optimizer.AdaBound') +OPTIMIZERS['ada_mod'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdaMod') +OPTIMIZERS['adam_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AdamP') +OPTIMIZERS['agg_mo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.AggMo') +OPTIMIZERS['diff_grad'] = _LazyImport(lr=_LR, import_path='torch_optimizer.DiffGrad') +OPTIMIZERS['lamb'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Lamb') +# 'torch_optimizer.Lookahead' is skipped because it is wrapped +OPTIMIZERS['novograd'] = _LazyImport(lr=_LR, import_path='torch_optimizer.NovoGrad') +OPTIMIZERS['pid'] = _LazyImport(lr=_LR, import_path='torch_optimizer.PID') +OPTIMIZERS['qh_adam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHAdam') +OPTIMIZERS['qhm'] = _LazyImport(lr=_LR, import_path='torch_optimizer.QHM') +OPTIMIZERS['radam'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RAdam') +OPTIMIZERS['ranger'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Ranger') +OPTIMIZERS['ranger_qh'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerQH') +OPTIMIZERS['ranger_va'] = _LazyImport(lr=_LR, import_path='torch_optimizer.RangerVA') +OPTIMIZERS['sgd_w'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDW') +OPTIMIZERS['sgd_p'] = _LazyImport(lr=_LR, import_path='torch_optimizer.SGDP') +OPTIMIZERS['shampoo'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Shampoo') +OPTIMIZERS['yogi'] = _LazyImport(lr=_LR, import_path='torch_optimizer.Yogi') + + +# ========================================================================= # +# METRIC - should be synchronized with: `disent/metrics/__init__.py` # +# ========================================================================= # + + +# TODO: this is not yet used in disent.util.lightning.callbacks or disent.metrics +METRICS = _Registry('METRICS') +METRICS['dci'] = _LazyImport('disent.metrics._dci.metric_dci') +METRICS['factor_vae'] = _LazyImport('disent.metrics._factor_vae.metric_factor_vae') +METRICS['mig'] = _LazyImport('disent.metrics._mig.metric_mig') +METRICS['sap'] = _LazyImport('disent.metrics._sap.metric_sap') +METRICS['unsupervised'] = _LazyImport('disent.metrics._unsupervised.metric_unsupervised') + + +# ========================================================================= # +# SCHEDULE - should be synchronized with: `disent/schedule/__init__.py` # # ========================================================================= # -# self-reference -- for testing purposes -class REGISTRY(metaclass=_R.LazyImportMeta()): - DATASET = _disent.registry.DATASET - SAMPLER = _disent.registry.SAMPLER - FRAMEWORK = _disent.registry.FRAMEWORK - RECON_LOSS = _disent.registry.RECON_LOSS - LATENT_DIST = _disent.registry.LATENT_DIST - OPTIMIZER = _disent.registry.OPTIMIZER - METRIC = _disent.registry.METRIC - SCHEDULE = _disent.registry.SCHEDULE - MODEL = _disent.registry.MODEL +# TODO: this is not yet used in disent.framework or disent.schedule +SCHEDULES = _Registry('SCHEDULES') +SCHEDULES['clip'] = _LazyImport('disent.schedule._schedule.ClipSchedule') +SCHEDULES['cosine_wave'] = _LazyImport('disent.schedule._schedule.CosineWaveSchedule') +SCHEDULES['cyclic'] = _LazyImport('disent.schedule._schedule.CyclicSchedule') +SCHEDULES['linear'] = _LazyImport('disent.schedule._schedule.LinearSchedule') +SCHEDULES['noop'] = _LazyImport('disent.schedule._schedule.NoopSchedule') # ========================================================================= # -# cleanup # +# MODEL - should be synchronized with: `disent/model/ae/__init__.py` # # 
========================================================================= #
-del _disent
-del _torch
-del _torch_optimizer
+# TODO: this is not yet used in disent.framework or disent.model
+MODELS = _Registry('MODELS')
+# [ENCODER]
+MODELS['encoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.EncoderConv64')
+MODELS['encoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.EncoderConv64Norm')
+MODELS['encoder_fc'] = _LazyImport('disent.model.ae._vae_fc.EncoderFC')
+MODELS['encoder_linear'] = _LazyImport('disent.model.ae._linear.EncoderLinear')
+# [DECODER]
+MODELS['decoder_conv64'] = _LazyImport('disent.model.ae._vae_conv64.DecoderConv64')
+MODELS['decoder_conv64norm'] = _LazyImport('disent.model.ae._norm_conv64.DecoderConv64Norm')
+MODELS['decoder_fc'] = _LazyImport('disent.model.ae._vae_fc.DecoderFC')
+MODELS['decoder_linear'] = _LazyImport('disent.model.ae._linear.DecoderLinear')
+
+
+# ========================================================================= #
+# Registry of all Registries                                                #
+# ========================================================================= #
+
+
+# registry of registries
+REGISTRIES = _Registry('REGISTRIES')
+REGISTRIES['DATASETS'] = DATASETS
+REGISTRIES['SAMPLERS'] = SAMPLERS
+REGISTRIES['FRAMEWORKS'] = FRAMEWORKS
+REGISTRIES['RECON_LOSSES'] = RECON_LOSSES
+REGISTRIES['LATENT_DISTS'] = LATENT_DISTS
+REGISTRIES['OPTIMIZERS'] = OPTIMIZERS
+REGISTRIES['METRICS'] = METRICS
+REGISTRIES['SCHEDULES'] = SCHEDULES
+REGISTRIES['MODELS'] = MODELS
+
+
+# ========================================================================= #
+# END                                                                       #
+# ========================================================================= #
diff --git a/disent/registry/_registry.py b/disent/registry/_registry.py
new file mode 100644
index 00000000..dbf4d911
--- /dev/null
+++ b/disent/registry/_registry.py
@@ -0,0 +1,225 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
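# --------------------------------------------------------------------------- #
# [editor's note] illustrative sketch only -- not part of this commit.
# Assuming the `Registry`/`LazyImport` behaviour defined in this new file, the
# registries declared in `disent/registry/__init__.py` above can be used
# roughly as follows (lookups resolve their lazy imports on first access).
# `MyEncoder` / `my_encoder` are hypothetical names used purely for illustration.
# --------------------------------------------------------------------------- #

from disent.registry import MODELS, OPTIMIZERS, REGISTRIES

EncoderConv64 = MODELS['encoder_conv64']   # resolves the lazy import to the encoder class
make_adam = OPTIMIZERS['adam']             # partial of torch.optim.Adam with the default `_LR` above

# register a custom entry, mirroring the decorator usage from the module docstring
@MODELS.register(aliases=('my_encoder',))
class MyEncoder:
    ...

assert 'MODELS' in REGISTRIES              # every registry is listed in the registry of registries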
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from typing import Any +from typing import Callable +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypeVar + +from disent.util.function import wrapped_partial +from disent.util.imports import import_obj_partial +from disent.util.imports import _check_and_split_path + + +# ========================================================================= # +# Basic Cached Item # +# ========================================================================= # + + +T = TypeVar('T') + + +class LazyValue(object): + + def __init__(self, generate_fn: Callable[[], T]): + assert callable(generate_fn) + self._is_generated = False + self._generate_fn = generate_fn + self._value = None + + def generate(self) -> T: + # replace value -- we don't actually need caching of the + # values since the registry replaces these items automatically, + # but LazyValue is exposed so it might be used unexpectedly. + if not self._is_generated: + self._is_generated = True + self._value = self._generate_fn() + self._generate_fn = None + # return value + return self._value + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self._generate_fn)})' + + +# ========================================================================= # +# Import Helper # +# ========================================================================= # + + +class LazyImport(LazyValue): + def __init__(self, import_path: str, *partial_args, **partial_kwargs): + super().__init__( + generate_fn=lambda: import_obj_partial(import_path, *partial_args, **partial_kwargs), + ) + + +# ========================================================================= # +# Registry # +# ========================================================================= # + + +_NONE = object() + + +class Registry(object): + + def __init__( + self, + name: str, + assert_valid_key: Optional[Callable[[str], NoReturn]] = None, + assert_valid_value: Optional[Callable[[T], NoReturn]] = None, + ): + # checks! + assert str.isidentifier(name), f'The given name for the registry is not a valid identifier: {repr(name)}' + self._name = name + assert (assert_valid_key is None) or callable(assert_valid_key), f'assert_valid_key must be None or callable' + assert (assert_valid_value is None) or callable(assert_valid_value), f'assert_valid_value must be None or callable' + self._assert_valid_key = assert_valid_key + self._assert_valid_value = assert_valid_value + # storage + self._keys_to_values: Dict[str, Any] = {} + + @property + def name(self) -> str: + return self._name + + def _get_aliases(self, name, aliases, auto_alias: bool): + if auto_alias: + if name not in self: + aliases = (name, *aliases) + elif not aliases: + raise RuntimeError(f'automatic alias: {repr(name)} already exists but no alternative aliases were specified.') + return aliases + + def register( + self, + fn=_NONE, + aliases: Sequence[str] = (), + auto_alias: bool = True, + partial_args: Tuple[Any, ...] 
= None, + partial_kwargs: Dict[str, Any] = None, + ) -> Callable[[T], T]: + def _decorator(orig_fn): + # try add the name of the object + keys = self._get_aliases(orig_fn.__name__, aliases=aliases, auto_alias=auto_alias) + # wrap function + new_fn = orig_fn + if (partial_args is not None) or (partial_kwargs is not None): + new_fn = wrapped_partial( + orig_fn, + *(() if partial_args is None else partial_args), + **({} if partial_kwargs is None else partial_kwargs), + ) + # register the function + self.register_value(value=new_fn, aliases=keys) + return orig_fn + # handle case + if fn is _NONE: + return _decorator + else: + return _decorator(fn) + + def register_import( + self, + import_path: str, + aliases: Sequence[str] = (), + auto_alias: bool = True, + *partial_args, + **partial_kwargs, + ) -> 'Registry': + (*_, name) = _check_and_split_path(import_path) + return self.register_value( + value=LazyImport(import_path=import_path, *partial_args, **partial_kwargs), + aliases=self._get_aliases(name, aliases=aliases, auto_alias=auto_alias), + ) + + def register_value(self, value: T, aliases: Sequence[str]) -> 'Registry': + # check keys + if len(aliases) < 1: + raise ValueError(f'aliases must be specified, got an empty sequence') + for k in aliases: + if not str.isidentifier(k): + raise ValueError(f'alias is not a valid identifier: {repr(k)}') + if k in self._keys_to_values: + raise RuntimeError(f'registry already contains key: {repr(k)}') + self.assert_valid_key(k) + # handle lazy values -- defer checking if a lazy value + if not isinstance(value, LazyValue): + self.assert_valid_value(value) + # register keys + for k in aliases: + self._keys_to_values[k] = value + return self + + def __setitem__(self, aliases: str, value: T): + if isinstance(aliases, str): + aliases = (aliases,) + assert isinstance(aliases, tuple), f'multiple aliases must be provided as a Tuple[str], got: {repr(aliases)}' + self.register_value(value=value, aliases=aliases) + + def __contains__(self, key: str): + return key in self._keys_to_values + + def __getitem__(self, key: str): + if key not in self._keys_to_values: + raise KeyError(f'registry does not contain the key: {repr(key)}, valid keys include: {sorted(self._keys_to_values.keys())}') + # get the value + value = self._keys_to_values[key] + # replace lazy value + if isinstance(value, LazyValue): + value = value.generate() + # check value & run deferred checks + if isinstance(value, LazyValue): + raise RuntimeError(f'{LazyValue.__name__} should never return other lazy values.') + self.assert_valid_value(value) + # update the value + self._keys_to_values[key] = value + # return the value + return value + + def __iter__(self): + yield from self._keys_to_values.keys() + + def assert_valid_value(self, value: T) -> T: + if self._assert_valid_value is not None: + self._assert_valid_value(value) + return value + + def assert_valid_key(self, key: str) -> str: + if self._assert_valid_key is not None: + self._assert_valid_key(key) + return key + + def __repr__(self): + return f'{self.__class__.__name__}({repr(self._name)}, ...)' + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/registry/_registry_util.py b/disent/registry/_registry_util.py deleted file mode 100644 index 9331f8ee..00000000 --- a/disent/registry/_registry_util.py +++ /dev/null @@ -1,118 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# 
MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -# ========================================================================= # -# Registry Helper # -# ========================================================================= # - - -class PathBuilder(object): - """ - Path builder stores the path taken down attributes - - This is used to trick pycharm type hinting. In the example - below, `Cars3dData` will be an instance of `_PathBuilder`, - but will type hint to `disent.dataset.data.Cars3dData` - ``` - disent = _PathBuilder() - if False: - import disent.dataset.data - Cars3dData = disent.dataset.data._groundtruth__cars3d.Cars3dData - ``` - """ - - def __init__(self, *segments): - self.__segments = tuple(segments) - - def __getattr__(self, item: str): - return PathBuilder(*self.__segments, item) - - def _do_import_(self): - import importlib - import_module, import_name = '.'.join(self.__segments[:-1]), self.__segments[-1] - try: - module = importlib.import_module(import_module) - except Exception as e: - raise ImportError(f'failed to import module: {repr(import_module)} ({".".join(self.__segments)})') from e - try: - obj = getattr(module, import_name) - except Exception as e: - raise ImportError(f'failed to get attribute on module: {repr(import_name)} ({".".join(self.__segments)})') from e - return obj - - -def LazyImportMeta(to_lowercase: bool = True): - """ - Lazy import paths metaclass checks for stored instances of `_PathBuilder` on a class and returns the - imported version of the attribute instead of the `_PathBuilder` itself. 
- - Used to perform lazy importing of classes and objects inside a module - """ - - if to_lowercase: - def transform(item): - if isinstance(item, str): - return item.lower() - return item - else: - def transform(item): - return item - - class _LazyImportMeta: - def __init__(cls, name, bases, dct): - cls.__unimported = {} # Dict[str, _PathBuilder] - cls.__imported = {} # Dict[str, Any] - # check annotations - for key, value in dct.items(): - if isinstance(value, PathBuilder): - assert str.isidentifier(key), f'registry key is not an identifier: {repr(key)}' - key = transform(key) - cls.__unimported[key] = value - - def __contains__(cls, item): - item = transform(item) - return (item in cls.__unimported) - - def __getitem__(cls, item): - item = transform(item) - if item not in cls.__imported: - if item not in cls.__unimported: - raise KeyError(f'invalid key: {repr(item)}, must be one of: {sorted(cls.__unimported.keys())}') - cls.__imported[item] = cls.__unimported[item]._do_import_() - return cls.__imported[item] - - def __getattr__(cls, item): - item = transform(item) - if item not in cls.__unimported: - raise AttributeError(f'invalid attribute: {repr(item)}, must be one of: {sorted(cls.__unimported.keys())}') - return cls[item] - - def __iter__(self): - yield from (transform(item) for item in self.__unimported.keys()) - - return _LazyImportMeta - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/schedule/__init__.py b/disent/schedule/__init__.py index 36cde1c5..c9ea8030 100644 --- a/disent/schedule/__init__.py +++ b/disent/schedule/__init__.py @@ -30,6 +30,7 @@ from ._schedule import CyclicSchedule from ._schedule import LinearSchedule from ._schedule import NoopSchedule +from ._schedule import SingleSchedule # aliases from ._schedule import ClipSchedule as Clip @@ -37,3 +38,4 @@ from ._schedule import CyclicSchedule as Cyclic from ._schedule import LinearSchedule as Linear from ._schedule import NoopSchedule as Noop +from ._schedule import SingleSchedule as Single diff --git a/disent/schedule/_schedule.py b/disent/schedule/_schedule.py index 0367833b..41db86d5 100644 --- a/disent/schedule/_schedule.py +++ b/disent/schedule/_schedule.py @@ -22,11 +22,14 @@ # SOFTWARE. 
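# [editor's note] illustrative sketch only -- not part of this commit. Assuming the
# reworked schedule API below (start_step/end_step in place of min_step/max_step), a
# schedule scales a base value by a ratio that ramps from r_start to r_end; the numbers
# in the comments follow from the LinearSchedule/lerp definitions in this hunk.

from disent.schedule import LinearSchedule

schedule = LinearSchedule(start_step=100, end_step=900, r_start=0.0, r_end=1.0)
schedule.compute_value(step=0,    value=4.0)  # -> 0.0 (before start_step: r_start * value)
schedule.compute_value(step=500,  value=4.0)  # -> 2.0 (halfway through the ramp)
schedule.compute_value(step=1000, value=4.0)  # -> 4.0 (after end_step: r_end * value)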
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +from typing import Union +from typing import Optional + import numpy as np from disent.schedule.lerp import cyclical_anneal +from disent.schedule.lerp import lerp from disent.schedule.lerp import lerp_step -from disent.schedule.lerp import scale # ========================================================================= # @@ -60,12 +63,23 @@ def compute_value(self, step: int, value): # ========================================================================= # -def _common(value, ratio, a, b): - # scale the ratio (which should be in the range [0, 1]) between [r_min, r_max] - sratio = scale(ratio, a, b) - # multiply the value - result = value * sratio - return result +def _common_lerp_value(ratio, value, r_start: float, r_end: float): + # scale the value such that it (which should be in the range [0, 1]) between [r_min, r_max] + return lerp( + ratio, + start_val=value * r_start, + end_val=value * r_end, + ) + + +def _completion_ratio(step: int, start_step: int, end_step: int): + ratio = lerp_step( + step=(step - start_step), + max_step=(end_step - start_step), + start_val=0.0, + end_val=1.0, + ) + return ratio class LinearSchedule(Schedule): @@ -76,23 +90,32 @@ class LinearSchedule(Schedule): computed value that is in the range [0, 1] """ - def __init__(self, min_step: int, max_step: int, r_start: float = 0.0, r_end: float = 1.0): - assert max_step > 0 - assert min_step >= 0 - assert min_step < max_step - self.min_step = min_step - self.max_step = max_step + def __init__( + self, + start_step: int, + end_step: int, + r_start: float = 0.0, + r_end: float = 1.0, + ): + """ + :param min_step: The step at which the schedule starts and the value unfreezes + :param max_step: The step at which the schedule finishes and the value freezes + :param r_start: The ratio of the original value that the schedule will start with + :param r_end: The ratio of the original value that the schedule will end with + """ + assert start_step >= 0 + assert end_step > 0 + assert start_step < end_step + self.start_step = start_step + self.end_step = end_step self.r_start = r_start self.r_end = r_end def compute_value(self, step: int, value): - ratio = lerp_step( - step=(step - self.min_step), - max_step=(self.max_step - self.min_step), - a=0.0, - b=1.0, - ) - return _common(value, ratio, a=self.r_start, b=self.r_end) + # completion ratio in range [0, 1]. If step < start_step return 0, if step > end_step return 1 + ratio = _completion_ratio(step=step, start_step=self.start_step, end_step=self.end_step) + # lerp the value into the range [r_start * value, r_end * value] according to the ratio + return _common_lerp_value(ratio, value=value, r_start=self.r_start, r_end=self.r_end) class CyclicSchedule(Schedule): @@ -105,10 +128,41 @@ class CyclicSchedule(Schedule): """ # TODO: maybe move this api into cyclical_anneal - def __init__(self, period: int, repeats: int = None, r_start=0.0, r_end=1.0, end_value='end', mode='linear', p_low=0.0, p_high=0.0): + def __init__( + self, + period: int, + start_step: Optional[int] = None, + repeats: Optional[int] = None, + r_start: float = 0.0, + r_end: float = 1.0, + end_mode: str = 'end', + mode: str = 'linear', + p_low: float = 0.0, + p_high: float = 0.0, + ): + """ + :param period: The number of steps it takes for the schedule to repeat + :param start_step: The step when the schedule will start, if this is None + then no modification to the step is performed. 
Equivalent to + `start_step=0` if no negative step values are passed. + :param repeats: The number of repeats of this schedule. The end_step of the schedule will + be `start_step + repeats*period`. If `repeats is None` or `repeats < 0` then the + schedule never ends. + :param r_start: The ratio of the original value that the schedule will start with + :param r_end: The ratio of the original value that the schedule will end with + :param end_mode: what of value the schedule should take after finishing [start, end] + :param mode: The kind of function use to interpolate between the start and finish [linear, sigmoid, cosine] + :param p_low: The portion of the period at the start that is spent at the minimum value + :param p_high: The portion of the period that at the end is spent at the maximum value + """ + # checks + if (repeats is not None) and (repeats < 0): + repeats = None + # set values self.period = period self.repeats = repeats - self.end_value = {'start': 'low', 'end': 'high'}[end_value] + self.start_step = start_step + self.end_value = {'start': 'low', 'end': 'high'}[end_mode] self.mode = mode # scale values self.r_start = r_start @@ -116,8 +170,13 @@ def __init__(self, period: int, repeats: int = None, r_start=0.0, r_end=1.0, end # portions of low and high -- low + high <= 1.0 -- low + slope + high == 1.0 self.p_low = p_low self.p_high = p_high + # checks + assert (start_step is None) or (start_step >= 0) def compute_value(self, step: int, value): + # shift the start + if self.start_step is not None: + step = max(0, step - self.start_step) # outputs value in range [0, 1] ratio = cyclical_anneal( step=step, @@ -129,7 +188,7 @@ def compute_value(self, step: int, value): end_value=self.end_value, mode=self.mode ) - return _common(value, ratio, a=self.r_start, b=self.r_end) + return _common_lerp_value(ratio, value=value, r_start=self.r_start, r_end=self.r_end) class SingleSchedule(CyclicSchedule): @@ -141,14 +200,31 @@ class SingleSchedule(CyclicSchedule): computed value that is in the range [0, 1] """ - def __init__(self, max_step, r_start=0.0, r_end=1.0, mode='linear'): + def __init__( + self, + start_step: int, + end_step: int, + r_start: float = 0.0, + r_end: float = 1.0, + mode: str = 'linear', + ): + """ + :param start_step: The step when the schedule will start + :param end_step: The step when the schedule will finish + :param r_start: The ratio of the original value that the schedule will start with + :param r_end: The ratio of the original value that the schedule will end with + :param mode: The kind of function use to interpolate between the start and finish [linear, sigmoid, cosine] + """ super().__init__( - period=max_step, + period=(end_step - start_step), + start_step=start_step, repeats=1, r_start=r_start, r_end=r_end, - end_value='end', + end_mode='end', mode=mode, + p_low=0.0, # adjust the start and end steps instead + p_high=0.0, # adjust the start and end steps instead ) @@ -161,15 +237,29 @@ class CosineWaveSchedule(Schedule): computed value that is in the range [0, 1] """ - def __init__(self, period: int, r_start: float = 0.0, r_end: float = 1.0): + # TODO: add r_start + # TODO: add start_step + # TODO: add repeats + def __init__( + self, + period: int, + r_start: float = 0.0, + r_end: float = 1.0, + ): + """ + :param period: The number of steps it takes for the schedule to repeat + :param r_start: The ratio of the original value that the schedule will start with + :param r_end: The ratio of the original value that the schedule will end with + """ assert period > 0 
self.period = period self.r_start = r_start self.r_end = r_end def compute_value(self, step: int, value): - ratio = 0.5 * (1 + np.cos(step * (2 * np.pi / self.period) + np.pi)) - return _common(value, ratio, a=self.r_start, b=self.r_end) + cosine_ratio = 0.5 * (1 + np.cos(step * (2 * np.pi / self.period) + np.pi)) + # lerp the value into the range [r_start * value, r_end * value] according to the ratio + return _common_lerp_value(cosine_ratio, value=value, r_start=self.r_start, r_end=self.r_end) # ========================================================================= # @@ -182,15 +272,36 @@ class ClipSchedule(Schedule): This schedule shifts the step, or clips the value """ - def __init__(self, schedule: Schedule, min_step=None, max_step=None, shift_step=True, min_value=None, max_value=None): + def __init__( + self, + schedule: Schedule, + min_step: Optional[int] = None, + max_step: Optional[int] = None, + shift_step: Union[bool, int] = True, + min_value: Optional[float] = None, + max_value: Optional[float] = None, + ): + """ + :param schedule: + :param min_step: The minimum step passed to the sub-schedule + :param max_step: The maximum step passed to the sub-schedule + :param shift_step: (if bool) Shift all the step values passed to the sub-schedule, + at or before min_step the sub-schedule will get `0`, at or after + max_step the sub-schedule will get `max_step-shift_step` + (if int) Add the given value to the step passed to the sub-schedule + :param min_value: The minimum value returned from the sub-schedule + :param max_value: The maximum value returned from the sub-schedule + """ assert isinstance(schedule, Schedule) self.schedule = schedule # step settings - self.min_step = min_step if (min_step is not None) else 0 + self.min_step = min_step self.max_step = max_step - if isinstance(shift_step, bool): - shift_step = (-self.min_step) if shift_step else None + # shift step self.shift_step = shift_step + if isinstance(shift_step, bool): + if self.min_step is not None: + self.shift_step = -self.min_step # value settings self.min_value = min_value self.max_value = max_value @@ -208,3 +319,47 @@ def compute_value(self, step: int, value): # ========================================================================= # # END # # ========================================================================= # + + +if __name__ == '__main__': + + def plot_schedules(*schedules: Schedule, total: int = 1000, value=1): + import matplotlib.pyplot as plt + fig, axs = plt.subplots(1, len(schedules), figsize=(3*len(schedules), 3)) + for ax, s in zip(axs, schedules): + xs = list(range(total)) + ys = [s(x, value) for x in xs] + ax.set_xlim([-0.05 * total, total + 0.05 * total]) + ax.set_ylim([-0.05 * value, value + 0.05 * value]) + ax.plot(xs, ys) + ax.set_title(f'{s.__class__.__name__}') + fig.tight_layout() + plt.show() + + def main(): + + # these should be equivalent + plot_schedules( + LinearSchedule(start_step=100, end_step=900, r_start=0.1, r_end=0.8), + LinearSchedule(start_step=200, end_step=800, r_start=0.9, r_end=0.2), + SingleSchedule(start_step=100, end_step=900, r_start=0.1, r_end=0.8), + SingleSchedule(start_step=200, end_step=800, r_start=0.9, r_end=0.2), + # LinearSchedule(min_step=900, max_step=100, r_start=0.1, r_end=0.8), # INVALID + # LinearSchedule(min_step=900, max_step=100, r_start=0.8, r_end=0.1), # INVALID + ) + + plot_schedules( + CyclicSchedule(period=300, start_step=0, repeats=None, r_start=0.1, r_end=0.8, end_mode='end', mode='linear', p_low=0.00, p_high=0.00), + 
CyclicSchedule(period=300, start_step=0, repeats=2, r_start=0.9, r_end=0.2, end_mode='start', mode='linear', p_low=0.25, p_high=0.00), + CyclicSchedule(period=300, start_step=200, repeats=2, r_start=0.9, r_end=0.2, end_mode='end', mode='linear', p_low=0.00, p_high=0.25), + CyclicSchedule(period=300, start_step=0, repeats=2, r_start=0.1, r_end=0.8, end_mode='end', mode='cosine', p_low=0.25, p_high=0.25), + CyclicSchedule(period=300, start_step=250, repeats=None, r_start=0.1, r_end=0.8, end_mode='end', mode='sigmoid', p_low=0.00, p_high=0.00), + ) + + plot_schedules( + SingleSchedule(start_step=0, end_step=800, r_start=0.1, r_end=0.8, mode='linear'), + SingleSchedule(start_step=100, end_step=800, r_start=0.8, r_end=0.1, mode='linear'), + SingleSchedule(start_step=100, end_step=800, r_start=0.8, r_end=0.1, mode='linear'), + ) + + main() diff --git a/disent/schedule/lerp.py b/disent/schedule/lerp.py index 5803c785..8e03ec25 100644 --- a/disent/schedule/lerp.py +++ b/disent/schedule/lerp.py @@ -33,22 +33,19 @@ # ========================================================================= # -def scale(r, a, b): - return a + r * (b - a) - - -def lerp(r, a, b): +def lerp(ratio, start_val, end_val): """Linear interpolation between parameters, respects bounds when t is out of bounds [0, 1]""" # assert a < b - r = np.clip(r, 0., 1.) + r = np.clip(ratio, 0., 1.) # precise method, guarantees v==b when t==1 | simplifies to: a + t*(b-a) - return (1 - r) * a + r * b + return (1 - r) * start_val + r * end_val + # return start_val + r * (end_val - start_val) # EQUIVALENT -def lerp_step(step, max_step, a, b): +def lerp_step(step, max_step, start_val, end_val): """Linear interpolation based on a step count.""" assert max_step > 0 - return lerp(step / max_step, a, b) + return lerp(ratio=step / max_step, start_val=start_val, end_val=end_val) # ========================================================================= # @@ -85,12 +82,12 @@ def scale_ratio(r, mode='linear'): def cyclical_anneal( step: Union[int, float, np.ndarray], period: float = 3600, - low_ratio=0.0, - high_ratio=0.0, + low_ratio: float = 0.0, + high_ratio: float = 0.0, repeats: int = None, - start_low=True, - end_value='high', - mode='linear', + start_low: bool = True, + end_value: str = 'high', + mode: str = 'linear', ): # check values assert 0 <= low_ratio <= 1 diff --git a/disent/util/array.py b/disent/util/array.py new file mode 100644 index 00000000..3f92bc02 --- /dev/null +++ b/disent/util/array.py @@ -0,0 +1,55 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import numpy as np +from wandb.wandb_torch import torch + + +# ========================================================================= # +# DEBUG # +# ========================================================================= # + + +def replace_arrays_with_shapes(obj): + """ + recursively replace all arrays of an object + with their shapes to make debugging easier! + """ + if isinstance(obj, (torch.Tensor, np.ndarray)): + return obj.shape + elif isinstance(obj, dict): + return {replace_arrays_with_shapes(k): replace_arrays_with_shapes(v) for k, v in obj.items()} + elif isinstance(obj, list): + return list(replace_arrays_with_shapes(v) for v in obj) + elif isinstance(obj, tuple): + return tuple(replace_arrays_with_shapes(v) for v in obj) + elif isinstance(obj, set): + return {replace_arrays_with_shapes(k) for k in obj} + else: + return obj + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/deprecate.py b/disent/util/deprecate.py index 97d8dc11..96fc654c 100644 --- a/disent/util/deprecate.py +++ b/disent/util/deprecate.py @@ -24,6 +24,7 @@ import logging from functools import wraps +from typing import Optional # ========================================================================= # @@ -31,13 +32,55 @@ # ========================================================================= # -def deprecated(msg: str): +def _get_traceback_string() -> str: + from io import StringIO + import traceback + # print the stack trace to an in-memory buffer + file = StringIO() + traceback.print_stack(file=file) + return file.getvalue() + + +def _get_traceback_file_groups(): + # filter the lines + results = [] + group = [] + for line in _get_traceback_string().splitlines(): + if line.strip().startswith('File "'): + if group: + results.append(group) + group = [] + group.append(line) + if group: + results.append(group) + return results + + +def _get_stack_file_strings(): + # mimic the output of a traceback so pycharm performs syntax highlighting when printed + import inspect + results = [] + for frame_info in inspect.stack(): + results.append(f'File "{frame_info.filename}", line {frame_info.lineno}, in {frame_info.function}') + return results[::-1] + + +_TRACEBACK_MODES = {'none', 'first', 'mini', 'traceback'} +DEFAULT_TRACEBACK_MODE = 'first' + + +def deprecated(msg: str, traceback_mode: Optional[str] = None): """ Mark a function or class as deprecated, and print a warning the first time it is used. - This decorator wraps functions, but only replaces the __init__ method of a class so that we can still inherit from a deprecated class! """ + assert isinstance(msg, str), f'msg must be a str, got type: {type(msg)}' + if traceback_mode is None: + traceback_mode = DEFAULT_TRACEBACK_MODE + assert traceback_mode in _TRACEBACK_MODES, f'invalid traceback_mode, got: {repr(traceback_mode)}, must be one of: {sorted(_TRACEBACK_MODES)}' + def _decorator(fn): # we need to handle classes and function separately is_class = isinstance(fn, type) and hasattr(fn, '__init__') @@ -51,8 +94,19 @@ def _caller(*args, **kwargs): # print the message! 
if dat is not None: name, path, dsc = dat - logging.getLogger(name).warning(f'[DEPRECATED] {path} - {repr(dsc)}') + logger = logging.getLogger(name) + logger.warning(f'[DEPRECATED] {path} - {repr(dsc)}') + # get stack trace lines + if traceback_mode == 'first': lines = _get_stack_file_strings()[-3:-2] + elif traceback_mode == 'mini': lines = _get_stack_file_strings()[:-2] + elif traceback_mode == 'traceback': lines = (l[2:] for g in _get_traceback_file_groups()[:-3] for l in g) + else: lines = [] + # print lines + for line in lines: + logger.warning(f'| {line}') + # never run this again dat = None + # call the function return call_fn(*args, **kwargs) # handle function or class if is_class: diff --git a/disent/util/imports.py b/disent/util/imports.py new file mode 100644 index 00000000..cf0b9f6b --- /dev/null +++ b/disent/util/imports.py @@ -0,0 +1,75 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + + +from typing import Tuple + + +# ========================================================================= # +# Import Helper # +# ========================================================================= # + + +def _check_and_split_path(import_path: str) -> Tuple[str, ...]: + segments = import_path.split('.') + # make sure each segment is a valid python identifier + if not all(map(str.isidentifier, segments)): + raise ValueError(f'import path is invalid: {repr(import_path)}') + # return the segments! 
+ return tuple(segments) + + +def import_obj(import_path: str): + # checks + segments = _check_and_split_path(import_path) + # split path + module_path, attr_name = '.'.join(segments[:-1]), segments[-1] + # import the module + import importlib + try: + module = importlib.import_module(module_path) + except Exception as e: + raise ImportError(f'failed to import module: {repr(module_path)}') from e + # import the attrs + try: + attr = getattr(module, attr_name) + except Exception as e: + raise ImportError(f'failed to get attribute: {repr(attr_name)} on module: {repr(module_path)}') from e + # done + return attr + + +def import_obj_partial(import_path: str, *partial_args, **partial_kwargs): + obj = import_obj(import_path) + # wrap the object if needed + if partial_args or partial_kwargs: + from disent.util.function import wrapped_partial + obj = wrapped_partial(obj, *partial_args, **partial_kwargs) + # done! + return obj + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/inout/files.py b/disent/util/inout/files.py index e27d6ebc..7b1cbf2e 100644 --- a/disent/util/inout/files.py +++ b/disent/util/inout/files.py @@ -24,7 +24,9 @@ import logging import os +from pathlib import Path from typing import Optional +from typing import Union from uuid import uuid4 from disent.util.inout.paths import uri_parse_file_or_url @@ -58,16 +60,15 @@ class AtomicSaveFile(object): def __init__( self, - file: str, + file: Union[str, Path], open_mode: Optional[str] = None, overwrite: bool = False, makedirs: bool = True, tmp_prefix: Optional[str] = '.temp.', tmp_suffix: Optional[str] = None, ): - from pathlib import Path # check files - if not file: + if not file or not Path(file).name: raise ValueError(f'file must not be empty: {repr(file)}') # get files self.trg_file = Path(file).absolute() diff --git a/disent/util/inout/hashing.py b/disent/util/inout/hashing.py index 3ad3ae7e..762e1a9f 100644 --- a/disent/util/inout/hashing.py +++ b/disent/util/inout/hashing.py @@ -111,7 +111,11 @@ def normalise_hash(hash: Union[str, Dict[str, str]], hash_mode: str) -> str: - Allow hashes to be dictionaries that map the hash_mode to the hash. This function returns the correct hash if it is a dictionary. 
""" - return hash[hash_mode] if isinstance(hash, dict) else hash + if isinstance(hash, dict): + if hash_mode not in hash: + raise KeyError(f'hash dictionary does not contain a key for the specified mode: {repr(hash_mode)}, available hashes are: {hash}') + return hash[hash_mode] + return hash def validate_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: str = 'md5', hash_mode: str = 'full', missing_ok=True): @@ -137,4 +141,3 @@ def is_valid_file_hash(file: str, hash: Union[str, Dict[str, str]], hash_type: s # ========================================================================= # # file hashing # # ========================================================================= # - diff --git a/disent/util/lightning/callbacks/__init__.py b/disent/util/lightning/callbacks/__init__.py index 23ea1019..60b1dff4 100644 --- a/disent/util/lightning/callbacks/__init__.py +++ b/disent/util/lightning/callbacks/__init__.py @@ -27,5 +27,6 @@ from disent.util.lightning.callbacks._callbacks_pl import LoggerProgressCallback -from disent.util.lightning.callbacks._callbacks_vae import VaeDisentanglementLoggingCallback +from disent.util.lightning.callbacks._callbacks_vae import VaeMetricLoggingCallback from disent.util.lightning.callbacks._callbacks_vae import VaeLatentCycleLoggingCallback +from disent.util.lightning.callbacks._callbacks_vae import VaeGtDistsLoggingCallback diff --git a/disent/util/lightning/callbacks/_callbacks_pl.py b/disent/util/lightning/callbacks/_callbacks_pl.py index 006bbdc4..5e8086aa 100644 --- a/disent/util/lightning/callbacks/_callbacks_pl.py +++ b/disent/util/lightning/callbacks/_callbacks_pl.py @@ -56,7 +56,8 @@ def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, curren if hasattr(trainer, 'batch_idx'): batch = (trainer.batch_idx + 1) else: - warnings.warn('batch_idx missing on pl.Trainer') + # TODO: re-enable this warning but only ever print once! + # warnings.warn('batch_idx missing on pl.Trainer') batch = global_step % max_batches # might not be int? 
# completion train_pct = global_step / max_steps diff --git a/disent/util/lightning/callbacks/_callbacks_vae.py b/disent/util/lightning/callbacks/_callbacks_vae.py index 37e311a9..f033a1c6 100644 --- a/disent/util/lightning/callbacks/_callbacks_vae.py +++ b/disent/util/lightning/callbacks/_callbacks_vae.py @@ -24,7 +24,12 @@ import logging import warnings +from typing import Callable +from typing import List from typing import Literal +from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Union import numpy as np @@ -32,19 +37,26 @@ import torch from pytorch_lightning.trainer.supporters import CombinedLoader +from torch.utils.data.dataloader import default_collate +from tqdm import tqdm import disent.metrics import disent.util.strings.colors as c from disent.dataset import DisentDataset from disent.dataset.data import GroundTruthData from disent.frameworks.ae import Ae +from disent.frameworks.helper.reconstructions import make_reconstruction_loss +from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.vae import Vae +from disent.util.function import wrapped_partial +from disent.util.iters import chunked from disent.util.lightning.callbacks._callbacks_base import BaseCallbackPeriodic from disent.util.lightning.logger_util import log_metrics from disent.util.lightning.logger_util import wb_log_metrics from disent.util.lightning.logger_util import wb_log_reduced_summaries from disent.util.profiling import Timer from disent.util.seeds import TempNumpySeed +from disent.util.visualize.plot import plt_subplots_imshow from disent.util.visualize.vis_model import latent_cycle_grid_animation from disent.util.visualize.vis_util import make_image_grid @@ -90,68 +102,401 @@ def _get_dataset_and_vae(trainer: pl.Trainer, pl_module: pl.LightningModule, unw # Vae Framework Callbacks # # ========================================================================= # +# helper +def _to_dmat( + size: int, + i_a: np.ndarray, + i_b: np.ndarray, + dists: Union[torch.Tensor, np.ndarray], +) -> np.ndarray: + if isinstance(dists, torch.Tensor): + dists = dists.detach().cpu().numpy() + # checks + assert i_a.ndim == 1 + assert i_a.shape == i_b.shape + assert i_a.shape == dists.shape + # compute + dmat = np.zeros([size, size], dtype='float32') + dmat[i_a, i_b] = dists + dmat[i_b, i_a] = dists + return dmat + + +_AE_DIST_NAMES = ('x', 'z_l1', 'x_recon') +_VAE_DIST_NAMES = ('x', 'z_l1', 'kl', 'x_recon') + + +@torch.no_grad() +def _get_dists_ae(ae: Ae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor): + # feed forware + z_a, z_b = ae.encode(x_a), ae.encode(x_b) + r_a, r_b = ae.decode(z_a), ae.decode(z_b) + # distances + return [ + recon_loss.compute_pairwise_loss(x_a, x_b), + torch.norm(z_a - z_b, p=1, dim=-1), # l1 dist + recon_loss.compute_pairwise_loss(r_a, r_b), + ] + + +@torch.no_grad() +def _get_dists_vae(vae: Vae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor): + from torch.distributions import kl_divergence + # feed forward + (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b) + z_a, z_b = z_post_a.mean, z_post_b.mean + r_a, r_b = vae.decode(z_a), vae.decode(z_b) + # dists + kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a) + # distances + return [ + recon_loss.compute_pairwise_loss(x_a, x_b), + torch.norm(z_a - z_b, p=1, dim=-1), # l1 dist + recon_loss._pairwise_reduce(kl_ab), + 
recon_loss.compute_pairwise_loss(r_a, r_b), + ] + + +def _get_dists_fn(model, recon_loss: ReconLossHandler) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]: + # get aggregate function + if isinstance(model, Vae): + dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model, recon_loss) + elif isinstance(model, Ae): + dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model, recon_loss) + else: + dists_names, dists_fn = None, None + return dists_names, dists_fn + + +@torch.no_grad() +def _collect_dists_subbatches(dists_fn: Callable[[object, object], Sequence[Sequence[float]]], batch: torch.Tensor, i_a: np.ndarray, i_b: np.ndarray, batch_size: int = 64): + # feed forward + results = [] + for idxs in chunked(np.stack([i_a, i_b], axis=-1), chunk_size=batch_size): + ia, ib = idxs.T + x_a, x_b = batch[ia], batch[ib] + # feed forward + data = dists_fn(x_a, x_b) + results.append(data) + return [torch.cat(r, dim=0) for r in zip(*results)] + + +def _compute_and_collect_dists( + dataset: DisentDataset, + dists_fn, + dists_names: Sequence[str], + traversal_repeats: int = 100, + batch_size: int = 32, + include_gt_factor_dists: bool = True, + transform_batch: Callable[[object], object] = None, + data_mode: str = 'input', +) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]: + assert traversal_repeats > 0 + gt_data = dataset.gt_data + # generate + f_grid = [] + # generate + for f_idx, f_size in enumerate(gt_data.factor_sizes): + # save for the current factor (traversal_repeats, len(names), len(i_a)) + f_dists = [] + # upper triangle excluding diagonal + i_a, i_b = np.triu_indices(f_size, k=1) + # repeat over random traversals + for i in range(traversal_repeats): + # get random factor traversal + factors = gt_data.sample_random_factor_traversal(f_idx=f_idx) + indices = gt_data.pos_to_idx(factors) + # load data + batch = dataset.dataset_batch_from_indices(indices, data_mode) + if transform_batch is not None: + batch = transform_batch(batch) + # feed forward & compute dists -- (len(names), len(i_a)) + dists = _collect_dists_subbatches(dists_fn=dists_fn, batch=batch, i_a=i_a, i_b=i_b, batch_size=batch_size) + assert len(dists) == len(dists_names) + # distances + f_dists.append(dists) + # aggregate all dists into distances matrices for current factor + f_dmats = [ + _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=torch.stack(dists, dim=0).mean(dim=0)) + for dists in zip(*f_dists) + ] + # handle factors + if include_gt_factor_dists: + i_dmat = _to_dmat(size=f_size, i_a=i_a, i_b=i_b, dists=np.abs(factors[i_a] - factors[i_b]).sum(axis=-1)) + f_dmats = [i_dmat, *f_dmats] + # append data + f_grid.append(f_dmats) + # handle factors + if include_gt_factor_dists: + dists_names = ('factors', *dists_names) + # done + return tuple(dists_names), f_grid + + +def compute_factor_distances( + dataset: DisentDataset, + dists_fn, + dists_names: Sequence[str], + traversal_repeats: int = 100, + batch_size: int = 32, + include_gt_factor_dists: bool = True, + transform_batch: Callable[[object], object] = None, + seed: Optional[int] = 777, + data_mode: str = 'input', +) -> Tuple[Tuple[str, ...], List[List[np.ndarray]]]: + # log this callback + gt_data = dataset.gt_data + log.info(f'| {gt_data.name} - computing factor distances...') + # compute various distances matrices for each factor + with Timer() as timer, TempNumpySeed(seed): + dists_names, f_grid = _compute_and_collect_dists( + dataset=dataset, + dists_fn=dists_fn, + 
dists_names=dists_names, + traversal_repeats=traversal_repeats, + batch_size=batch_size, + include_gt_factor_dists=include_gt_factor_dists, + transform_batch=transform_batch, + data_mode=data_mode, + ) + # log this callback! + log.info(f'| {gt_data.name} - computed factor distances! time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST}') + return dists_names, f_grid + + +def plt_factor_distances( + gt_data: GroundTruthData, + f_grid: List[List[np.ndarray]], + dists_names: Sequence[str], + title: str, + plt_block_size: float = 1.25, + plt_transpose: bool = False, + plt_cmap='Blues', +): + # plot information + imshow_kwargs = dict(cmap=plt_cmap) + figsize = (plt_block_size*len(f_grid[0]), plt_block_size * gt_data.num_factors) + # plot! + if not plt_transpose: + fig, axs = plt_subplots_imshow(grid=f_grid, col_labels=dists_names, row_labels=gt_data.factor_names, figsize=figsize, title=title, imshow_kwargs=imshow_kwargs) + else: + fig, axs = plt_subplots_imshow(grid=list(zip(*f_grid)), col_labels=gt_data.factor_names, row_labels=dists_names, figsize=figsize[::-1], title=title, imshow_kwargs=imshow_kwargs) + # done + return fig, axs + + +class VaeGtDistsLoggingCallback(BaseCallbackPeriodic): + + def __init__( + self, + seed: Optional[int] = 7777, + every_n_steps: Optional[int] = None, + traversal_repeats: int = 100, + begin_first_step: bool = False, + plt_block_size: float = 1.25, + plt_show: bool = False, + plt_transpose: bool = False, + log_wandb: bool = True, + batch_size: int = 128, + include_factor_dists: bool = True, + ): + assert traversal_repeats > 0 + self._traversal_repeats = traversal_repeats + self._seed = seed + self._recon_loss = make_reconstruction_loss('mse', 'mean') + self._plt_block_size = plt_block_size + self._plt_show = plt_show + self._log_wandb = log_wandb + self._include_gt_factor_dists = include_factor_dists + self._transpose_plot = plt_transpose + self._batch_size = batch_size + super().__init__(every_n_steps, begin_first_step) + + @torch.no_grad() + def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + # get dataset and vae framework from trainer and module + dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True) + # exit early + if not dataset.is_ground_truth: + log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!') + return + # get aggregate function + dists_names, dists_fn = _get_dists_fn(vae, self._recon_loss) + if (dists_names is None) or (dists_fn is None): + log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}') + return + # compute various distances matrices for each factor + dists_names, f_grid = compute_factor_distances( + dataset=dataset, + dists_fn=dists_fn, + dists_names=dists_names, + traversal_repeats=self._traversal_repeats, + batch_size=self._batch_size, + include_gt_factor_dists=self._include_gt_factor_dists, + transform_batch=lambda batch: batch.to(vae.device), + seed=self._seed, + data_mode='input', + ) + # plot these results + fig, axs = plt_factor_distances( + gt_data=dataset.gt_data, + f_grid=f_grid, + dists_names=dists_names, + title=f'{vae.__class__.__name__}: {dataset.gt_data.name.capitalize()} Distances', + plt_block_size=self._plt_block_size, + plt_transpose=self._transpose_plot, + plt_cmap='Blues', + ) + # show the plot + if self._plt_show: + plt.show() + # log the plot to wandb + if self._log_wandb: + wb_log_metrics(trainer.logger, { + 'factor_distances': wandb.Image(fig) + }) + + +def 
_normalize_min_max_mean_std_to_min_max(recon_min, recon_max, recon_mean, recon_std) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
+    # check recon_min and recon_max
+    if (recon_min is not None) or (recon_max is not None):
+        if (recon_mean is not None) or (recon_std is not None):
+            raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
+        if (recon_min is None) or (recon_max is None):
+            raise ValueError('both recon_min & recon_max must be specified')
+        # check strings
+        if isinstance(recon_min, str) or isinstance(recon_max, str):
+            if not (isinstance(recon_min, str) and isinstance(recon_max, str)):
+                raise ValueError('both recon_min & recon_max must be "auto" if one is "auto"')
+            return None, None
+    # check recon_mean and recon_std
+    elif (recon_mean is not None) or (recon_std is not None):
+        if (recon_min is not None) or (recon_max is not None):
+            raise ValueError('must choose either recon_min & recon_max OR recon_mean & recon_std, cannot specify both')
+        if (recon_mean is None) or (recon_std is None):
+            raise ValueError('both recon_mean & recon_std must be specified')
+        # set values:
+        # | ORIG: [0, 1]
+        # | TRANSFORM: (x - mean) / std -> [(0-mean)/std, (1-mean)/std]
+        # | REVERT: (x - min) / (max - min) -> [0, 1]
+        # | min=(0-mean)/std, max=(1-mean)/std
+        recon_mean, recon_std = np.array(recon_mean, dtype='float32'), np.array(recon_std, dtype='float32')
+        recon_min = np.divide(0 - recon_mean, recon_std)
+        recon_max = np.divide(1 - recon_mean, recon_std)
+    # set defaults
+    if recon_min is None: recon_min = 0.0
+    if recon_max is None: recon_max = 1.0
+    # change type
+    recon_min = np.array(recon_min)
+    recon_max = np.array(recon_max)
+    assert recon_min.ndim in (0, 1)
+    assert recon_max.ndim in (0, 1)
+    # checks
+    assert np.all(recon_min < recon_max), f'recon_min={recon_min} must be less than recon_max={recon_max}'
+    return recon_min, recon_max
+
+
 class VaeLatentCycleLoggingCallback(BaseCallbackPeriodic):
 
-    def __init__(self, seed=7777, every_n_steps=None, begin_first_step=False, mode='fitted_gaussian_cycle', plt_show=False, plt_block_size=1.0, recon_min: Union[int, Literal['auto']] = 0., recon_max: Union[int, Literal['auto']] = 1.):
+    def __init__(
+        self,
+        seed: Optional[int] = 7777,
+        every_n_steps: Optional[int] = None,
+        begin_first_step: bool = False,
+        num_frames: int = 17,
+        mode: str = 'fitted_gaussian_cycle',
+        wandb_mode: str = 'both',
+        wandb_fps: int = 4,
+        plt_show: bool = False,
+        plt_block_size: float = 1.0,
+        # recon_min & recon_max
+        recon_min: Optional[Union[int, Literal['auto']]] = None,  # scale data in this range [min, max] to [0, 1]
+        recon_max: Optional[Union[int, Literal['auto']]] = None,  # scale data in this range [min, max] to [0, 1]
+        recon_mean: Optional[Union[Tuple[float, ...], float]] = None,  # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
+        recon_std: Optional[Union[Tuple[float, ...], float]] = None,   # automatically converted to min & max [(0-mean)/std, (1-mean)/std], assuming original range of values is [0, 1]
+    ):
         super().__init__(every_n_steps, begin_first_step)
         self.seed = seed
         self.mode = mode
         self.plt_show = plt_show
         self.plt_block_size = plt_block_size
-        self._recon_min = recon_min
-        self._recon_max = recon_max
+        self._wandb_mode = wandb_mode
+        self._num_frames = num_frames
+        self._fps = wandb_fps
+        # checks
+        assert wandb_mode in {'none', 'img', 'vid', 'both'}, f'invalid wandb_mode={repr(wandb_mode)}, must be one of: ("none", "img", "vid", "both")'
+        # normalize
+        self._recon_min, self._recon_max = _normalize_min_max_mean_std_to_min_max(
+            recon_min=recon_min,
+            recon_max=recon_max,
+            recon_mean=recon_mean,
+            recon_std=recon_std,
+        )
+
     def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         # get dataset and vae framework from trainer and module
         dataset, vae = _get_dataset_and_vae(trainer, pl_module, unwrap_groundtruth=True)
+        # TODO: should this not use `visualize_dataset_traversal`?
+
         with torch.no_grad():
             # get random sample of z_means and z_logvars for computing the range of values for the latent_cycle
             with TempNumpySeed(self.seed):
-                obs = dataset.dataset_sample_batch(64, mode='input').to(vae.device)
+                batch = dataset.dataset_sample_batch(64, mode='input').to(vae.device)
             # get representations
             if isinstance(vae, Vae):
                 # variational auto-encoder
-                ds_posterior, ds_prior = vae.encode_dists(obs)
+                ds_posterior, ds_prior = vae.encode_dists(batch)
                 zs_mean, zs_logvar = ds_posterior.mean, torch.log(ds_posterior.variance)
-            else:
+            elif isinstance(vae, Ae):
                 # auto-encoder
-                zs_mean = vae.encode(obs)
+                zs_mean = vae.encode(batch)
                 zs_logvar = torch.ones_like(zs_mean)
+            else:
+                log.warning(f'cannot run {self.__class__.__name__}, unsupported type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
+                return
             # get min and max if auto
-            if (self._recon_min == 'auto') or (self._recon_max == 'auto'):
-                if self._recon_min == 'auto': self._recon_min = float(torch.min(obs).cpu())
-                if self._recon_max == 'auto': self._recon_max = float(torch.max(obs).cpu())
-                log.info(f'auto visualisation min: {self._recon_min} and max: {self._recon_max} obtained from {len(obs)} samples')
+            if (self._recon_min is None) or (self._recon_max is None):
+                if self._recon_min is None: self._recon_min = float(torch.min(batch).cpu())
+                if self._recon_max is None: self._recon_max = float(torch.max(batch).cpu())
+                log.info(f'auto visualisation min: {self._recon_min} and max: {self._recon_max} obtained from {len(batch)} samples')
             # produce latent cycle grid animation
             # TODO: this needs to be fixed to not use logvar, but rather the representations or distributions themselves
-            frames, stills = latent_cycle_grid_animation(
+            animation, stills = latent_cycle_grid_animation(
                 vae.decode, zs_mean, zs_logvar,
-                mode=self.mode, num_frames=21, decoder_device=vae.device, tensor_style_channels=False, return_stills=True,
+                mode=self.mode, num_frames=self._num_frames, decoder_device=vae.device, tensor_style_channels=False, return_stills=True,
                 to_uint8=True, recon_min=self._recon_min, recon_max=self._recon_max,
             )
+            image = make_image_grid(stills.reshape(-1, *stills.shape[2:]), num_cols=stills.shape[1], pad=4)
-            # log video
-            wb_log_metrics(trainer.logger, {
-                self.mode: wandb.Video(np.transpose(frames, [0, 3, 1, 2]), fps=4, format='mp4'),
-            })
+            # log video -- none, img, vid, both
+            wandb_items = {}
+            if self._wandb_mode in ('img', 'both'): wandb_items[f'{self.mode}_img'] = wandb.Image(image)
+            if self._wandb_mode in ('vid', 'both'): wandb_items[f'{self.mode}_vid'] = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=self._fps, format='mp4')
+            wb_log_metrics(trainer.logger, wandb_items)
+            # log locally
             if self.plt_show:
-                grid = make_image_grid(np.reshape(stills, (-1, *stills.shape[2:])), num_cols=stills.shape[1], pad=4)
                 fig, ax = plt.subplots(1, 1, figsize=(self.plt_block_size*stills.shape[1], self.plt_block_size*stills.shape[0]))
-                ax.imshow(grid)
+                ax.imshow(image)
                 ax.axis('off')
                 fig.tight_layout()
                 plt.show()


-class 
VaeDisentanglementLoggingCallback(BaseCallbackPeriodic): +class VaeMetricLoggingCallback(BaseCallbackPeriodic): - def __init__(self, step_end_metrics=None, train_end_metrics=None, every_n_steps=None, begin_first_step=False): + def __init__( + self, + step_end_metrics: Optional[Sequence[str]] = None, + train_end_metrics: Optional[Sequence[str]] = None, + every_n_steps: Optional[int] = None, + begin_first_step: bool = False, + ): super().__init__(every_n_steps, begin_first_step) self.step_end_metrics = step_end_metrics if step_end_metrics else [] self.train_end_metrics = train_end_metrics if train_end_metrics else [] diff --git a/disent/util/math/random.py b/disent/util/math/random.py index 0b10fc8b..2e7f86b3 100644 --- a/disent/util/math/random.py +++ b/disent/util/math/random.py @@ -31,12 +31,15 @@ # ========================================================================= # -def random_choice_prng(a, size=None, replace=True): +def random_choice_prng(a, size=None, replace=True, seed: int = None): + # generate a random seed + if seed is None: + seed = np.random.randint(0, 2**32) # create seeded pseudo random number generator # - built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672 # - PCG64 is the default: https://numpy.org/doc/stable/reference/random/bit_generators/index.html # - PCG64 has good statistical properties and is fast: https://numpy.org/doc/stable/reference/random/performance.html - g = np.random.Generator(np.random.PCG64(seed=np.random.randint(0, 2**32))) + g = np.random.Generator(np.random.PCG64(seed=seed)) # sample indices choices = g.choice(a, size=size, replace=replace) # done! diff --git a/disent/util/profiling.py b/disent/util/profiling.py index 2721387b..3befbdf8 100644 --- a/disent/util/profiling.py +++ b/disent/util/profiling.py @@ -36,12 +36,17 @@ # ========================================================================= # -def get_memory_usage(): +def get_memory_usage(pretty: bool = False): import os import psutil process = psutil.Process(os.getpid()) num_bytes = process.memory_info().rss # in bytes - return num_bytes + # format the bytes + if pretty: + from disent.util.strings.fmt import bytes_to_human + return bytes_to_human(num_bytes) + else: + return num_bytes # ========================================================================= # @@ -106,6 +111,11 @@ def __exit__(self, *args, **kwargs): else: log.log(self._log_level, f'{self.name}: {self.pretty}') + def restart(self): + assert self._start_time is not None, 'timer must have been started before we can restart it' + assert self._end_time is None, 'timer cannot be restarted if it is finished' + self._start_time = time.time_ns() + @property def elapsed_ns(self) -> int: if self._start_time is not None: @@ -159,4 +169,3 @@ def prettify_time(ns: int) -> str: # ========================================================================= # # END # # ========================================================================= # - diff --git a/disent/util/visualize/plot.py b/disent/util/visualize/plot.py new file mode 100644 index 00000000..f4bafb81 --- /dev/null +++ b/disent/util/visualize/plot.py @@ -0,0 +1,383 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, 
including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +from numbers import Number +from typing import Any +from typing import Dict +from typing import Optional + +import numpy as np +import torch +import logging + +from disent.dataset import DisentDataset +from disent.dataset.util.state_space import NonNormalisedFactors +from disent.util.seeds import TempNumpySeed +from disent.util.visualize.vis_util import make_animated_image_grid +from disent.util.visualize.vis_util import make_image_grid + +# TODO: matplotlib is not in requirements +from matplotlib import pyplot as plt + + +log = logging.getLogger(__name__) + + +# ========================================================================= # +# images # +# ========================================================================= # + + +# TODO: similar functions exist: output_image, to_img, to_imgs, reconstructions_to_images +def to_img(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> torch.Tensor: + assert x.ndim == 3, 'image must have 3 dimensions: (C, H, W)' + return to_imgs(x, scale=scale, to_cpu=to_cpu, move_channels=move_channels) + + +# TODO: similar functions exist: output_image, to_img, to_imgs, reconstructions_to_images +def to_imgs(x: torch.Tensor, scale=False, to_cpu=True, move_channels=True) -> torch.Tensor: + # (..., C, H, W) + assert x.ndim >= 3, 'image must have 3 or more dimensions: (..., C, H, W)' + assert x.dtype in {torch.float16, torch.float32, torch.float64, torch.complex32, torch.complex64}, f'unsupported dtype: {x.dtype}' + # no gradient + with torch.no_grad(): + # imaginary to real + if x.dtype in {torch.complex32, torch.complex64}: + x = torch.abs(x) + # scale images + if scale: + m = x.min(dim=-3, keepdim=True).values.min(dim=-2, keepdim=True).values.min(dim=-1, keepdim=True).values + M = x.max(dim=-3, keepdim=True).values.max(dim=-2, keepdim=True).values.max(dim=-1, keepdim=True).values + x = (x - m) / (M - m) + # move axis + if move_channels: + x = torch.moveaxis(x, -3, -1) + # to uint8 + x = torch.clamp(x, 0, 1) + x = (x * 255).to(torch.uint8) + # done! + x = x.detach() # is this needeed? 
+ if to_cpu: + x = x.cpu() + return x + + +# ========================================================================= # +# Matplotlib Helper # +# ========================================================================= # + + +def plt_imshow(img, figsize=12, show=False, **kwargs): + # check image shape + assert img.ndim == 3 + assert img.shape[-1] in (1, 3) + # figure size -- fixed width, adjust height according to image + if isinstance(figsize, (int, str, Number)): + size = np.array(img.shape[:2][::-1]) + figsize = tuple(size / size[0] * figsize) + # create plot + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, **kwargs) + plt_hide_axis(ax) + ax.imshow(img) + fig.tight_layout() + # done! + if show: + plt.show() + return fig, ax + + +def _hide(hide, cond): + assert hide in {True, False, 'all', 'edges', 'none'} + return (hide is True) or (hide == 'all') or (hide == 'edges' and cond) + + +def plt_subplots( + nrows: int = 1, ncols: int = 1, + # custom + title=None, + titles=None, + row_labels=None, + col_labels=None, + title_size: int = None, + titles_size: int = None, + label_size: int = None, + hide_labels='edges', # none, edges, all + hide_axis='edges', # none, edges, all + # plt.subplots: + sharex: str = False, + sharey: str = False, + subplot_kw=None, + gridspec_kw=None, + **fig_kw, +): + assert isinstance(nrows, int) + assert isinstance(ncols, int) + # check titles + if titles is not None: + titles = np.array(titles).reshape([nrows, ncols]) + # get labels + if (row_labels is None) or isinstance(row_labels, str): + row_labels = [row_labels] * nrows + if (col_labels is None) or isinstance(col_labels, str): + col_labels = [col_labels] * ncols + assert len(row_labels) == nrows, 'row_labels and nrows mismatch' + assert len(col_labels) == ncols, 'row_labels and nrows mismatch' + # check titles + if titles is not None: + assert len(titles) == nrows + assert len(titles[0]) == ncols + # create subplots + fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=False, subplot_kw=subplot_kw, gridspec_kw=gridspec_kw, **fig_kw) + # generate + for y in range(nrows): + for x in range(ncols): + ax = axs[y, x] + plt_hide_axis(ax, hide_xaxis=_hide(hide_axis, y != nrows-1), hide_yaxis=_hide(hide_axis, x != 0)) + # modify ax + if not _hide(hide_labels, y != nrows-1): + ax.set_xlabel(col_labels[x], fontsize=label_size) + if not _hide(hide_labels, x != 0): + ax.set_ylabel(row_labels[y], fontsize=label_size) + # set title + if titles is not None: + ax.set_title(titles[y][x], fontsize=titles_size) + # set title + fig.suptitle(title, fontsize=title_size) + # done! 
+ return fig, axs + + +def plt_subplots_imshow( + grid, + # custom: + title=None, + titles=None, + row_labels=None, + col_labels=None, + title_size: int = None, + titles_size: int = None, + label_size: int = None, + hide_labels='edges', # none, edges, all + hide_axis='all', # none, edges, all + # tight_layout: + subplot_padding: Optional[float] = 1.08, + # plt.subplots: + sharex: str = False, + sharey: str = False, + subplot_kw=None, + gridspec_kw=None, + # imshow + vmin: float = None, + vmax: float = None, + # extra + show: bool = False, + imshow_kwargs: dict = None, + **fig_kw, +): + # TODO: add automatic height & width + fig, axs = plt_subplots( + nrows=len(grid), ncols=len(grid[0]), + # custom + title=title, + titles=titles, + row_labels=row_labels, + col_labels=col_labels, + title_size=title_size, + titles_size=titles_size, + label_size=label_size, + hide_labels=hide_labels, # none, edges, all + hide_axis=hide_axis, # none, edges, all + # plt.subplots: + sharex=sharex, + sharey=sharey, + subplot_kw=subplot_kw, + gridspec_kw=gridspec_kw, + **fig_kw, + ) + # show images + for y, x in np.ndindex(axs.shape): + axs[y, x].imshow(grid[y][x], vmin=vmin, vmax=vmax, **(imshow_kwargs if imshow_kwargs else {})) + fig.tight_layout(**({} if (subplot_padding is None) else dict(pad=subplot_padding))) + # done! + if show: + plt.show() + return fig, axs + + +def plt_hide_axis(ax, hide_xaxis=True, hide_yaxis=True, hide_border=True, hide_axis_labels=False, hide_axis_ticks=True, hide_grid=True): + if hide_xaxis: + if hide_axis_ticks: + ax.set_xticks([]) + ax.set_xticklabels([]) + if hide_axis_labels: + ax.xaxis.label.set_visible(False) + if hide_yaxis: + if hide_axis_ticks: + ax.set_yticks([]) + ax.set_yticklabels([]) + if hide_axis_labels: + ax.yaxis.label.set_visible(False) + if hide_border: + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + if hide_grid: + ax.grid(False) + return ax + + +# ========================================================================= # +# Dataset Visualisation / Traversals # +# ========================================================================= # + + +def visualize_dataset_traversal( + dataset: DisentDataset, + # inputs + factor_names: Optional[NonNormalisedFactors] = None, + num_frames: int = 9, + seed: int = 777, + base_factors=None, + traverse_mode='cycle', + # images & animations + pad: int = 4, + border: bool = True, + bg_color: Number = None, + # augment + augment_fn: callable = None, + data_mode: str = 'raw', + # output + output_wandb: bool = False, +): + """ + Generic function that can return multiple parts of the dataset & factor traversal pipeline. + - This only evaluates what is needed to compute the next components. 
+ - The returned grid, image and animation will always have 3 channels, RGB + + Tasks include: + - factor_idxs + - factors + - grid + - image + - image_wandb + - image_plt + - animation + - animation_wandb + """ + + # get factors from dataset + factor_idxs = dataset.gt_data.normalise_factor_idxs(factor_names) + + # get factor traversals + with TempNumpySeed(seed): + factors = np.stack([ + dataset.gt_data.sample_random_factor_traversal(f_idx, base_factors=base_factors, num=num_frames, mode=traverse_mode) + for f_idx in factor_idxs + ], axis=0) + + # retrieve and augment image grid + grid = [dataset.dataset_batch_from_factors(f, mode=data_mode) for f in factors] + if augment_fn is not None: + grid = [augment_fn(batch) for batch in grid] + grid = np.stack(grid, axis=0) + + # TODO: this is kinda hacky, maybe rather add a check? + # TODO: can this be moved into the `output_wandb` if statement? + # - animations glitch out if they do not have 3 channels + if grid.shape[-1] == 1: + grid = grid.repeat(3, axis=-1) + + assert grid.ndim == 5 + assert grid.shape[-1] in (1, 3) + + # generate visuals + image = make_image_grid(np.concatenate(grid, axis=0), pad=pad, border=border, bg_color=bg_color, num_cols=num_frames) + animation = make_animated_image_grid(np.stack(grid, axis=0), pad=pad, border=border, bg_color=bg_color, num_cols=None) + + # convert to wandb + if output_wandb: + import wandb + wandb_image = wandb.Image(image) + wandb_animation = wandb.Video(np.transpose(animation, [0, 3, 1, 2]), fps=4, format='mp4') + return ( + wandb_image, + wandb_animation, + ) + + # return values + return ( + grid, # (FACTORS, NUM_FRAMES, H, W, C) + image, # ([[H+PAD]*[FACTORS+1]], [[W+PAD]*[NUM_FRAMES+1]], C) + animation, # (NUM_FRAMES, [H & FACTORS], [W & FACTORS], C) -- size is auto-chosen + ) + + +# ========================================================================= # +# 2d density plot # +# ========================================================================= # + + +def plt_2d_density( + x, + y, + n_bins: int = 300, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + ax: plt.Subplot = None, + pcolormesh_kwargs: Optional[Dict[str, Any]] = None +): + from scipy.stats import kde + # https://www.python-graph-gallery.com/85-density-plot-with-matplotlib + # convert inputs + x = np.array(x) + y = np.array(y) + # prevent singular + # x = np.random.randn(*x.shape) * (0.01 * max(x.max() - x.min(), 1)) + # y = np.random.randn(*y.shape) * (0.01 * max(y.max() - y.min(), 1)) + # get bounds + if xmin is None: xmin = x.min() + if xmax is None: xmax = x.max() + if ymin is None: ymin = y.min() + if ymax is None: ymax = y.max() + # Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents + xi, yi = np.mgrid[xmin:xmax:n_bins*1j, ymin:ymax:n_bins*1j] + try: + k = kde.gaussian_kde([x, y]) + zi = k(np.stack([xi.flatten(), yi.flatten()], axis=0)) + except np.linalg.LinAlgError: + log.warning('Could not create 2d_density plot') + return + # update args + if ax is None: ax = plt + if pcolormesh_kwargs is None: pcolormesh_kwargs = {} + # Make the plot + ax.pcolormesh(xi, yi, zi.reshape(xi.shape), shading='auto', **pcolormesh_kwargs) + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/disent/util/visualize/vis_model.py b/disent/util/visualize/vis_model.py index ae5fd182..b7735ce9 100644 --- 
a/disent/util/visualize/vis_model.py +++ b/disent/util/visualize/vis_model.py @@ -103,7 +103,9 @@ def _z_minmax_interval_cycle(base_z, z_means, z_logvars, z_idx, num_frames): # ========================================================================= # -def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_animations=4, num_frames=20, decoder_device=None, recon_min=0., recon_max=1.): +# TODO: this function should not convert output to images, it should just be +# left as is. That way we don't need to pass in the recon_min and recon_max +def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_animations=4, num_frames=20, decoder_device=None, recon_min=0., recon_max=1.) -> np.ndarray: assert len(z_means) > 1 and len(z_logvars) > 1, 'not enough samples to average' # convert z_means, z_logvars = to_numpy(z_means), to_numpy(z_logvars) @@ -117,12 +119,12 @@ def latent_cycle(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', for j in range(z_means.shape[1]): z = z_gen_func(base_z, z_means, z_logvars, j, num_frames) z = torch.as_tensor(z, device=decoder_device) - frames.append(reconstructions_to_images(decoder_func(z), recon_min=recon_min, recon_max=recon_max)) + frames.append(decoder_func(z)) animations.append(frames) - return to_numpy(animations) + return reconstructions_to_images(animations, recon_min=recon_min, recon_max=recon_max) -def latent_cycle_grid_animation(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_frames=21, pad=4, border=True, bg_color=0.5, decoder_device=None, tensor_style_channels=True, always_rgb=True, return_stills=False, to_uint8=False, recon_min=0., recon_max=1.): +def latent_cycle_grid_animation(decoder_func, z_means, z_logvars, mode='fixed_interval_cycle', num_frames=21, pad=4, border=True, bg_color=0.5, decoder_device=None, tensor_style_channels=True, always_rgb=True, return_stills=False, to_uint8=False, recon_min=0., recon_max=1.) -> np.ndarray: # produce latent cycle animation & merge frames stills = latent_cycle(decoder_func, z_means, z_logvars, mode=mode, num_animations=1, num_frames=num_frames, decoder_device=decoder_device, recon_min=recon_min, recon_max=recon_max)[0] # check and add missing channel if needed (convert greyscale to rgb images) diff --git a/disent/util/visualize/vis_util.py b/disent/util/visualize/vis_util.py index 26bf920e..7d46abe2 100644 --- a/disent/util/visualize/vis_util.py +++ b/disent/util/visualize/vis_util.py @@ -24,6 +24,8 @@ import logging import warnings +from typing import List +from typing import Union import numpy as np import scipy.stats @@ -217,7 +219,14 @@ def cycle_interval(starting_value, num_frames, min_val, max_val): # TODO: this functionality is duplicated elsewhere! # TODO: similar functions exist: output_image, to_img, to_imgs, reconstructions_to_images -def reconstructions_to_images(recon, mode='float', moveaxis=True, recon_min=0., recon_max=1.): +def reconstructions_to_images( + recon, + mode: str = 'float', + moveaxis: bool = True, + recon_min: Union[float, List[float]] = 0.0, + recon_max: Union[float, List[float]] = 1.0, + warn_if_clipped: bool = True, +) -> Union[np.ndarray, Image.Image]: """ Convert a batch of reconstructions to images. 
 A batch in this case consists of an arbitrary number of dimensions of an array,
@@ -232,16 +241,23 @@ def reconstructions_to_images(recon, mode='float', moveaxis=True, recon_min=0.,
     # checks
     assert img.ndim >= 3
     assert img.dtype in (np.float32, np.float64)
+    # move channels axis
+    if moveaxis:
+        img = np.moveaxis(img, -3, -1)
+    # check min and max
+    recon_min = np.array(recon_min)
+    recon_max = np.array(recon_max)
+    assert recon_min.shape == recon_max.shape
+    assert recon_min.ndim in (0, 1)  # supports per-channel or global min & max
     # scale image
     img = (img - recon_min) / (recon_max - recon_min)
     # check image bounds
-    if np.min(img) < 0 or np.max(img) > 1:
-        warnings.warn('images are being clipped between 0 and 1')
+    if warn_if_clipped:
+        m, M = np.min(img), np.max(img)
+        if m < 0 or M > 1:
+            log.warning(f'images with range [{m}, {M}] have been clipped to the range [0, 1]')
+    # do clipping
     img = np.clip(img, 0, 1)
-    # move channels axis
-    if moveaxis:
-        # TODO: automatically detect
-        img = np.moveaxis(img, -3, -1)
     # convert
     if mode == 'float':
         return img
diff --git a/docs/examples/mnist_example.py b/docs/examples/mnist_example.py
index f7e4d676..6e56354d 100644
--- a/docs/examples/mnist_example.py
+++ b/docs/examples/mnist_example.py
@@ -1,6 +1,5 @@
 import os
 import pytorch_lightning as pl
-from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torchvision import datasets
 from tqdm import tqdm
@@ -9,7 +8,7 @@
 from disent.frameworks.vae import AdaVae
 from disent.model import AutoEncoder
 from disent.model.ae import DecoderFC, EncoderFC
-from disent.nn.transform import ToStandardisedTensor
+from disent.dataset.transform import ToImgTensorF32


 # modify the mnist dataset to only return images, not labels
@@ -21,8 +20,8 @@ def __getitem__(self, index):


 # make mnist dataset -- adjust num_samples here to match framework. 
TODO: add tests that can fail with a warning -- dataset downloading is not always reliable data_folder = os.path.abspath(os.path.join(__file__, '../data/dataset')) -dataset_train = DisentDataset(MNIST(data_folder, train=True, download=True, transform=ToStandardisedTensor()), sampler=RandomSampler(num_samples=2)) -dataset_test = MNIST(data_folder, train=False, download=True, transform=ToStandardisedTensor()) +dataset_train = DisentDataset(MNIST(data_folder, train=True, download=True, transform=ToImgTensorF32()), sampler=RandomSampler(num_samples=2)) +dataset_test = MNIST(data_folder, train=False, download=True, transform=ToImgTensorF32()) # create the dataloaders dataloader_train = DataLoader(dataset=dataset_train, batch_size=128, shuffle=True, num_workers=os.cpu_count()) diff --git a/docs/examples/overview_data.py b/docs/examples/overview_data.py index 8da7e17d..ebd38818 100644 --- a/docs/examples/overview_data.py +++ b/docs/examples/overview_data.py @@ -1,9 +1,9 @@ from disent.dataset.data import XYObjectData -data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1') print(f'Number of observations: {len(data)} == {data.size}') -print(f'Observation shape: {data.observation_shape}') +print(f'Observation shape: {data.img_shape}') print(f'Num Factors: {data.num_factors}') print(f'Factor Names: {data.factor_names}') print(f'Factor Sizes: {data.factor_sizes}') diff --git a/docs/examples/overview_dataset_loader.py b/docs/examples/overview_dataset_loader.py index 5646df02..21aca53c 100644 --- a/docs/examples/overview_dataset_loader.py +++ b/docs/examples/overview_dataset_loader.py @@ -2,11 +2,11 @@ from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairOrigSampler -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 # prepare the data -data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') -dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1') +dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) # iterate over single epoch diff --git a/docs/examples/overview_dataset_pair.py b/docs/examples/overview_dataset_pair.py index 96e1f155..7f47bd09 100644 --- a/docs/examples/overview_dataset_pair.py +++ b/docs/examples/overview_dataset_pair.py @@ -1,12 +1,12 @@ from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairOrigSampler -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 # prepare the data -data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') -dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1') +dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) # iterate over single epoch for obs in dataset: 
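
For orientation, a minimal sketch of how the pair example above is typically consumed after this change (the `x_targ` key and the pair unpacking are assumptions based on the other disent overview examples, they are not part of this patch):

    from disent.dataset import DisentDataset
    from disent.dataset.data import XYObjectData
    from disent.dataset.sampling import GroundTruthPairOrigSampler
    from disent.dataset.transform import ToImgTensorF32

    # same tiny dataset as in the example above, using the renamed transform
    data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
    dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToImgTensorF32())

    # each observation is a dict of sampled pairs; 'x_targ' is assumed to hold the transformed pair
    for obs in dataset:
        x0, x1 = obs['x_targ']
        print(x0.shape, x0.dtype)  # float32 image tensors scaled to [0, 1]
        break
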
diff --git a/docs/examples/overview_dataset_pair_augment.py b/docs/examples/overview_dataset_pair_augment.py index 81b494c7..2dd14d11 100644 --- a/docs/examples/overview_dataset_pair_augment.py +++ b/docs/examples/overview_dataset_pair_augment.py @@ -1,12 +1,12 @@ from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairSampler -from disent.nn.transform import ToStandardisedTensor, FftBoxBlur +from disent.dataset.transform import ToImgTensorF32, FftBoxBlur # prepare the data -data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') -dataset = DisentDataset(data, sampler=GroundTruthPairSampler(), transform=ToStandardisedTensor(), augment=FftBoxBlur(radius=1, p=1.0)) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1') +dataset = DisentDataset(data, sampler=GroundTruthPairSampler(), transform=ToImgTensorF32(), augment=FftBoxBlur(radius=1, p=1.0)) # iterate over single epoch for obs in dataset: diff --git a/docs/examples/overview_dataset_single.py b/docs/examples/overview_dataset_single.py index 120b4268..1ce113cd 100644 --- a/docs/examples/overview_dataset_single.py +++ b/docs/examples/overview_dataset_single.py @@ -5,7 +5,7 @@ # - DisentDataset is a generic wrapper around torch Datasets that prepares # the data for the various frameworks according to some sampling strategy # by default this sampling strategy just returns the data at the given idx. -data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1') dataset = DisentDataset(data, transform=None, augment=None) # iterate over single epoch diff --git a/docs/examples/overview_framework_adagvae.py b/docs/examples/overview_framework_adagvae.py index 11c88009..e0d0d530 100644 --- a/docs/examples/overview_framework_adagvae.py +++ b/docs/examples/overview_framework_adagvae.py @@ -1,5 +1,4 @@ import pytorch_lightning as pl -from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData @@ -7,13 +6,13 @@ from disent.frameworks.vae import AdaVae from disent.model import AutoEncoder from disent.model.ae import DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.util import is_test_run # you can ignore and remove this # prepare the data data = XYObjectData() -dataset = DisentDataset(data, GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) +dataset = DisentDataset(data, GroundTruthPairOrigSampler(), transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) # create the pytorch lightning system diff --git a/docs/examples/overview_framework_ae.py b/docs/examples/overview_framework_ae.py index d73124d1..f848e305 100644 --- a/docs/examples/overview_framework_ae.py +++ b/docs/examples/overview_framework_ae.py @@ -1,19 +1,17 @@ import pytorch_lightning as pl -from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData -from disent.dataset.sampling import SingleSampler from disent.frameworks.ae import Ae from disent.model import AutoEncoder from disent.model.ae import 
DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.util import is_test_run # you can ignore and remove this # prepare the data data = XYObjectData() -dataset = DisentDataset(data, transform=ToStandardisedTensor()) +dataset = DisentDataset(data, transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) # create the pytorch lightning system diff --git a/docs/examples/overview_framework_betavae.py b/docs/examples/overview_framework_betavae.py index 14974506..478eb421 100644 --- a/docs/examples/overview_framework_betavae.py +++ b/docs/examples/overview_framework_betavae.py @@ -1,19 +1,17 @@ import pytorch_lightning as pl -from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData -from disent.dataset.sampling import SingleSampler from disent.frameworks.vae import BetaVae from disent.model import AutoEncoder from disent.model.ae import DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.util import is_test_run # you can ignore and remove this # prepare the data data = XYObjectData() -dataset = DisentDataset(data, transform=ToStandardisedTensor()) +dataset = DisentDataset(data, transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) # create the pytorch lightning system diff --git a/docs/examples/overview_framework_betavae_scheduled.py b/docs/examples/overview_framework_betavae_scheduled.py index 4b0d8550..290d55b2 100644 --- a/docs/examples/overview_framework_betavae_scheduled.py +++ b/docs/examples/overview_framework_betavae_scheduled.py @@ -1,19 +1,17 @@ import pytorch_lightning as pl -from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData -from disent.dataset.sampling import SingleSampler from disent.frameworks.vae import BetaVae from disent.model import AutoEncoder from disent.model.ae import DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.schedule import CyclicSchedule from disent.util import is_test_run # you can ignore and remove this # prepare the data data = XYObjectData() -dataset = DisentDataset(data, transform=ToStandardisedTensor()) +dataset = DisentDataset(data, transform=ToImgTensorF32()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) # create the pytorch lightning system diff --git a/docs/examples/overview_metrics.py b/docs/examples/overview_metrics.py index afafff2b..cd83103f 100644 --- a/docs/examples/overview_metrics.py +++ b/docs/examples/overview_metrics.py @@ -1,18 +1,16 @@ import pytorch_lightning as pl -from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset from disent.dataset.data import XYObjectData -from disent.dataset.sampling import SingleSampler from disent.frameworks.vae import BetaVae from disent.model import AutoEncoder from disent.model.ae import DecoderConv64, EncoderConv64 -from disent.nn.transform import ToStandardisedTensor +from disent.dataset.transform import ToImgTensorF32 from disent.metrics import metric_dci, metric_mig from disent.util import is_test_run data = XYObjectData() -dataset = DisentDataset(data, 
transform=ToStandardisedTensor(), augment=None) +dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None) dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True) def make_vae(beta): diff --git a/experiment/config/augment/basic.yaml b/experiment/config/augment/basic.yaml index 71fb89e0..2f44ce41 100644 --- a/experiment/config/augment/basic.yaml +++ b/experiment/config/augment/basic.yaml @@ -1,6 +1,6 @@ -# @package _group_ name: basic -transform: + +augment_cls: _target_: torchvision.transforms.RandomOrder transforms: - _target_: kornia.augmentation.ColorJitter @@ -40,4 +40,4 @@ transform: # degrees: 10 # translate: [0.14, 0.14] # scale: [0.95, 1.05] -# shear: 5 \ No newline at end of file +# shear: 5 diff --git a/experiment/config/augment/none.yaml b/experiment/config/augment/none.yaml index 15bfee47..013f4a6b 100644 --- a/experiment/config/augment/none.yaml +++ b/experiment/config/augment/none.yaml @@ -1,3 +1,3 @@ -# @package _group_ name: none -transform: NULL + +augment_cls: NULL diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index a06300bd..5f25a949 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -1,57 +1,47 @@ defaults: + # data + - sampling: default__bb + - dataset: xyobject + - augment: none # system - - framework: adavae + - framework: betavae - model: vae_conv64 + # training - optimizer: adam - schedule: none - # data - - dataset: xyobject - - dataset_sampling: full_bb - - augment: none - # runtime - - metrics: fast - - run_length: short - - run_location: local + - metrics: all + - run_length: long + # logs - run_callbacks: vis - run_logging: wandb - # plugins - - hydra/job_logging: colorlog - - hydra/hydra_logging: colorlog - - hydra/launcher: submitit_slurm - -job: - user: '${env:USER}' - project: 'test-project' - name: '${framework.name}:${framework.module.recon_loss}|${dataset.name}:${dataset_sampling.name}|${trainer.steps}' - partition: batch - seed: NULL - -framework: - beta: 4 - module: - recon_loss: mse - loss_reduction: mean_sum - - # only some frameworks support these features - optional: - latent_distribution: normal # only used by VAEs - overlap_loss: NULL - -model: - z_size: 25 - -optimizer: - lr: 1e-3 - -# CUSTOM DEFAULTS SPECIALIZATION -# - This key is deleted on load and the correct key on the root config is set similar to defaults. -# - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. 
-specializations: - # default samplers -- the framework specified defaults - dataset_sampler: ${dataset.data_type}_${framework.data_sample_mode} - - # newer samplers -- only active for frameworks that require 3 observations, otherwise random for 2, and exact for 1 - # dataset_sampler: gt_dist_${framework.data_sample_mode} - - # random samplers -- force random sampling of observation pairs or triples - # dataset_sampler: random_${framework.data_sample_mode} + # runtime + - run_location: stampede_shr + - run_launcher: slurm + - run_action: train + # entries in this file override entries from default lists + - _self_ + +settings: + job: + user: '${oc.env:USER}' + project: 'DELETE' + name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' + seed: NULL + + framework: + beta: 0.0316 + recon_loss: mse + loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling` + + framework_opt: + latent_distribution: normal # only used by VAEs + + model: + z_size: 25 + weight_init: 'xavier_normal' # xavier_normal, default + + dataset: + batch_size: 256 + + optimizer: + lr: 1e-3 diff --git a/experiment/config/config_test.yaml b/experiment/config/config_test.yaml index ca54be57..63a66c8a 100644 --- a/experiment/config/config_test.yaml +++ b/experiment/config/config_test.yaml @@ -1,48 +1,47 @@ defaults: - # experiment + # data + - sampling: default__bb + - dataset: xyobject + - augment: none + # system - framework: betavae - model: vae_conv64 + # training - optimizer: adam - - dataset: xyobject - - dataset_sampling: full_bb - - augment: none - - schedule: none + - schedule: beta_cyclic - metrics: test - # runtime - run_length: test - - run_location: local_cpu - - run_callbacks: vis_slow + # logs + - run_callbacks: test - run_logging: none - # plugins - - hydra/job_logging: colorlog - - hydra/hydra_logging: colorlog - - hydra/launcher: submitit_slurm + # runtime + - run_location: local_cpu + - run_launcher: local + - run_action: train + # entries in this file override entries from default lists + - _self_ + +settings: + job: + user: 'invalid' + project: 'invalid' + name: '${framework.name}:${settings.framework.recon_loss}|${dataset.name}:${sampling.name}|${trainer.max_steps}' + seed: NULL -job: - user: invalid - project: invalid - name: '${framework.name}:${framework.module.recon_loss}|${dataset.name}:${dataset_sampling.name}|${trainer.steps}' - partition: invalid - seed: NULL + framework: + beta: 0.0316 + recon_loss: mse + loss_reduction: mean # TODO: this should be renamed to `loss_mode="scaled"` or `enable_loss_scaling=True"` or `enable_beta_scaling` -framework: - beta: 0.001 - module: - recon_loss: mse - loss_reduction: mean - optional: - latent_distribution: normal # only used by VAEs - overlap_loss: NULL + framework_opt: + latent_distribution: normal # only used by VAEs -model: - z_size: 25 + model: + z_size: 25 + weight_init: 'xavier_normal' # xavier_normal, default -optimizer: - lr: 1e-3 + dataset: + batch_size: 5 -# CUSTOM DEFAULTS SPECIALIZATION -# - This key is deleted on load and the correct key on the root config is set similar to defaults. -# - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. 
-specializations: - dataset_sampler: ${dataset.data_type}_${framework.data_sample_mode} -# dataset_sampler: gt_dist_${framework.data_sample_mode} + optimizer: + lr: 1e-3 diff --git a/experiment/config/dataset/_data_type_/episodes.yaml b/experiment/config/dataset/_data_type_/episodes.yaml new file mode 100644 index 00000000..bd345806 --- /dev/null +++ b/experiment/config/dataset/_data_type_/episodes.yaml @@ -0,0 +1,2 @@ +# controlled by the data's defaults list +name: episodes diff --git a/experiment/config/dataset/_data_type_/gt.yaml b/experiment/config/dataset/_data_type_/gt.yaml new file mode 100644 index 00000000..7c7ed83a --- /dev/null +++ b/experiment/config/dataset/_data_type_/gt.yaml @@ -0,0 +1,2 @@ +# controlled by the data's defaults list +name: gt diff --git a/experiment/config/dataset/_data_type_/random.yaml b/experiment/config/dataset/_data_type_/random.yaml new file mode 100644 index 00000000..6f9aedfa --- /dev/null +++ b/experiment/config/dataset/_data_type_/random.yaml @@ -0,0 +1,2 @@ +# controlled by the data's defaults list +name: random diff --git a/experiment/config/dataset/cars3d.yaml b/experiment/config/dataset/cars3d.yaml index f32c2364..9cb9055e 100644 --- a/experiment/config/dataset/cars3d.yaml +++ b/experiment/config/dataset/cars3d.yaml @@ -1,12 +1,20 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: cars3d + data: _target_: disent.dataset.data.Cars3dData - data_root: ${dataset.data_root} - prepare: True + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + transform: - _target_: disent.nn.transform.ToStandardisedTensor + _target_: disent.dataset.transform.ToImgTensorF32 size: 64 -x_shape: [3, 64, 64] + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] + vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] diff --git a/experiment/config/dataset/dsprites.yaml b/experiment/config/dataset/dsprites.yaml index 0e108a4e..d2ada2dd 100644 --- a/experiment/config/dataset/dsprites.yaml +++ b/experiment/config/dataset/dsprites.yaml @@ -1,12 +1,20 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: dsprites + data: _target_: disent.dataset.data.DSpritesData - data_root: ${dataset.data_root} - prepare: True - in_memory: ${dataset.try_in_memory} + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + in_memory: ${dsettings.dataset.try_in_memory} + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [1, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [1, 64, 64] + vis_mean: [0.042494423521889584] + vis_std: [0.19516645880626055] diff --git a/experiment/config/dataset/monte_rollouts.yaml b/experiment/config/dataset/monte_rollouts.yaml index 91c9314a..682eeb54 100644 --- a/experiment/config/dataset/monte_rollouts.yaml +++ b/experiment/config/dataset/monte_rollouts.yaml @@ -1,13 +1,21 @@ -# @package _group_ +defaults: + - _data_type_: episodes + name: monte_rollouts + data: _target_: disent.dataset.data.EpisodesDownloadZippedPickledData - required_file: ${dataset.data_root}/episodes/monte.pkl + required_file: ${dsettings.storage.data_root}/episodes/monte.pkl download_url: 'https://raw.githubusercontent.com/nmichlo/uploads/main/monte_key.tar.xz' - force_download: FALSE + prepare: ${dsettings.dataset.prepare} 
+ transform: - _target_: disent.nn.transform.ToStandardisedTensor - size: [64, 64] -x_shape: [3, 64, 64] # [3, 210, 160] + _target_: disent.dataset.transform.ToImgTensorF32 + size: [64, 64] # slightly squashed? + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: episodes +meta: + x_shape: [3, 64, 64] # [3, 210, 160] + vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" + vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" diff --git a/experiment/config/dataset/mpi3d_real.yaml b/experiment/config/dataset/mpi3d_real.yaml index 24fbaacd..1e5da193 100644 --- a/experiment/config/dataset/mpi3d_real.yaml +++ b/experiment/config/dataset/mpi3d_real.yaml @@ -1,13 +1,21 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: mpi3d_real + data: _target_: disent.dataset.data.Mpi3dData - data_root: ${dataset.data_root} - prepare: True - in_memory: ${dataset.try_in_memory} + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + in_memory: ${dsettings.dataset.try_in_memory} subset: 'real' + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [3, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627] + vis_std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629] diff --git a/experiment/config/dataset/mpi3d_realistic.yaml b/experiment/config/dataset/mpi3d_realistic.yaml index 5e41ad3d..f1a81300 100644 --- a/experiment/config/dataset/mpi3d_realistic.yaml +++ b/experiment/config/dataset/mpi3d_realistic.yaml @@ -1,13 +1,21 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: mpi3d_realistic + data: _target_: disent.dataset.data.Mpi3dData - data_root: ${dataset.data_root} - prepare: True - in_memory: ${dataset.try_in_memory} + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + in_memory: ${dsettings.dataset.try_in_memory} subset: 'realistic' + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [3, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256] + vis_std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177] diff --git a/experiment/config/dataset/mpi3d_toy.yaml b/experiment/config/dataset/mpi3d_toy.yaml index 90024b18..de6674d2 100644 --- a/experiment/config/dataset/mpi3d_toy.yaml +++ b/experiment/config/dataset/mpi3d_toy.yaml @@ -1,13 +1,21 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: mpi3d_toy + data: _target_: disent.dataset.data.Mpi3dData - data_root: ${dataset.data_root} - prepare: True - in_memory: ${dataset.try_in_memory} + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + in_memory: ${dsettings.dataset.try_in_memory} subset: 'toy' + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [3, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702] + vis_std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426] diff --git 
a/experiment/config/dataset/shapes3d.yaml b/experiment/config/dataset/shapes3d.yaml index 1ad32106..a834768b 100644 --- a/experiment/config/dataset/shapes3d.yaml +++ b/experiment/config/dataset/shapes3d.yaml @@ -1,12 +1,20 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: 3dshapes + data: _target_: disent.dataset.data.Shapes3dData - data_root: ${dataset.data_root} - prepare: True - in_memory: ${dataset.try_in_memory} + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} + in_memory: ${dsettings.dataset.try_in_memory} + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [3, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] + vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] diff --git a/experiment/config/dataset/smallnorb.yaml b/experiment/config/dataset/smallnorb.yaml index f293805c..9dfbb8ec 100644 --- a/experiment/config/dataset/smallnorb.yaml +++ b/experiment/config/dataset/smallnorb.yaml @@ -1,13 +1,21 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: smallnorb + data: _target_: disent.dataset.data.SmallNorbData - data_root: ${dataset.data_root} - prepare: True + data_root: ${dsettings.storage.data_root} + prepare: ${dsettings.dataset.prepare} is_test: False + transform: - _target_: disent.nn.transform.ToStandardisedTensor + _target_: disent.dataset.transform.ToImgTensorF32 size: 64 -x_shape: [1, 64, 64] + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [1, 64, 64] + vis_mean: [0.7520918401088603] + vis_std: [0.09563879016827262] diff --git a/experiment/config/dataset/xyobject.yaml b/experiment/config/dataset/xyobject.yaml index 861a9b30..9dc791ee 100644 --- a/experiment/config/dataset/xyobject.yaml +++ b/experiment/config/dataset/xyobject.yaml @@ -1,16 +1,18 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: xyobject + data: _target_: disent.dataset.data.XYObjectData - grid_size: 64 - grid_spacing: 1 - min_square_size: 3 - max_square_size: 9 - square_size_spacing: 2 rgb: TRUE - palette: 'colors' + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [3, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [3, 64, 64] + vis_mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] + vis_std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] diff --git a/experiment/config/dataset/xyobject_grey.yaml b/experiment/config/dataset/xyobject_grey.yaml index 51b23f00..aeb92718 100644 --- a/experiment/config/dataset/xyobject_grey.yaml +++ b/experiment/config/dataset/xyobject_grey.yaml @@ -1,16 +1,18 @@ -# @package _group_ +defaults: + - _data_type_: gt + name: xyobject_grey + data: _target_: disent.dataset.data.XYObjectData - grid_size: 64 - grid_spacing: 1 - min_square_size: 3 - max_square_size: 9 - square_size_spacing: 2 rgb: FALSE - palette: 'white' + transform: - _target_: disent.nn.transform.ToStandardisedTensor -x_shape: [1, 64, 64] + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} -data_type: ground_truth +meta: + x_shape: [1, 64, 64] + vis_mean: "${exit:EXITING... 
please compute the vis_mean and vis_std}" + vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" diff --git a/experiment/config/dataset/xyobject_shaded.yaml b/experiment/config/dataset/xyobject_shaded.yaml new file mode 100644 index 00000000..5490ad2d --- /dev/null +++ b/experiment/config/dataset/xyobject_shaded.yaml @@ -0,0 +1,18 @@ +defaults: + - _data_type_: gt + +name: xyobject_shaded + +data: + _target_: disent.dataset.data.XYObjectShadedData + rgb: TRUE + +transform: + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} + +meta: + x_shape: [3, 64, 64] + vis_mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288] + vis_std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585] diff --git a/experiment/config/dataset/xyobject_shaded_grey.yaml b/experiment/config/dataset/xyobject_shaded_grey.yaml new file mode 100644 index 00000000..4993dafd --- /dev/null +++ b/experiment/config/dataset/xyobject_shaded_grey.yaml @@ -0,0 +1,18 @@ +defaults: + - _data_type_: gt + +name: xyobject_shaded_grey + +data: + _target_: disent.dataset.data.XYObjectShadedData + rgb: FALSE + +transform: + _target_: disent.dataset.transform.ToImgTensorF32 + mean: ${dataset.meta.vis_mean} + std: ${dataset.meta.vis_std} + +meta: + x_shape: [1, 64, 64] + vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" + vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" diff --git a/experiment/config/dataset_sampler/ground_truth_pair.yaml b/experiment/config/dataset_sampler/ground_truth_pair.yaml deleted file mode 100644 index 9ff80554..00000000 --- a/experiment/config/dataset_sampler/ground_truth_pair.yaml +++ /dev/null @@ -1,8 +0,0 @@ -# @package _group_ -name: ground_truth_pair -sampler: - _target_: disent.dataset.sampling.GroundTruthPairSampler - # factor sampling - p_k_range: ${dataset_sampling.k} - # radius sampling - p_radius_range: ${dataset_sampling.k_radius} diff --git a/experiment/config/dataset_sampler/ground_truth_triplet.yaml b/experiment/config/dataset_sampler/ground_truth_triplet.yaml deleted file mode 100644 index 99a5feaa..00000000 --- a/experiment/config/dataset_sampler/ground_truth_triplet.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# @package _group_ -name: ground_truth_triplet -sampler: - _target_: disent.dataset.sampling.GroundTruthTripleSampler - # factor sampling - p_k_range: ${dataset_sampling.k} - n_k_range: ${dataset_sampling.n_k} - n_k_sample_mode: ${dataset_sampling.n_k_mode} - n_k_is_shared: TRUE - # radius sampling - p_radius_range: ${dataset_sampling.k_radius} - n_radius_range: ${dataset_sampling.n_k_radius} - n_radius_sample_mode: ${dataset_sampling.n_k_radius_mode} - # final checks - swap_metric: ${dataset_sampling.swap_metric} - swap_chance: ${dataset_sampling.swap_chance} diff --git a/experiment/config/dataset_sampler/ground_truth_weak_pair.yaml b/experiment/config/dataset_sampler/ground_truth_weak_pair.yaml deleted file mode 100644 index edb8cc96..00000000 --- a/experiment/config/dataset_sampler/ground_truth_weak_pair.yaml +++ /dev/null @@ -1,6 +0,0 @@ -# @package _group_ -name: ground_truth_weak_pair -sampler: - _target_: disent.dataset.sampling.GroundTruthPairOrigSampler - # factor sampling - p_k: ${dataset_sampling.k.1} diff --git a/experiment/config/dataset_sampler/gt_dist_pair.yaml b/experiment/config/dataset_sampler/gt_dist_pair.yaml deleted file mode 100644 index 63b0156d..00000000 --- a/experiment/config/dataset_sampler/gt_dist_pair.yaml +++ /dev/null 
@@ -1,7 +0,0 @@ -# @package _group_ -name: gt_dist_pair -sampler: - _target_: disent.dataset.sampling.GroundTruthDistSampler - num_samples: 2 - triplet_sample_mode: ${dataset_sampling.triplet_sample_mode} # random, factors, manhattan, combined - triplet_swap_chance: ${dataset_sampling.triplet_swap_chance} diff --git a/experiment/config/dataset_sampler/gt_dist_single.yaml b/experiment/config/dataset_sampler/gt_dist_single.yaml deleted file mode 100644 index 0e207ae0..00000000 --- a/experiment/config/dataset_sampler/gt_dist_single.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _group_ -name: gt_dist_single -sampler: - _target_: disent.dataset.sampling.GroundTruthDistSampler - num_samples: 1 - triplet_sample_mode: ${dataset_sampling.triplet_sample_mode} # random, factors, manhattan, combined - triplet_swap_chance: ${dataset_sampling.triplet_swap_chance} diff --git a/experiment/config/dataset_sampler/gt_dist_triplet.yaml b/experiment/config/dataset_sampler/gt_dist_triplet.yaml deleted file mode 100644 index 56528a65..00000000 --- a/experiment/config/dataset_sampler/gt_dist_triplet.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# @package _group_ -name: gt_dist_single -sampler: - _target_: disent.dataset.sampling.GroundTruthDistSampler - num_samples: 3 - triplet_sample_mode: ${dataset_sampling.triplet_sample_mode} # random, factors, manhattan, combined - triplet_swap_chance: ${dataset_sampling.triplet_swap_chance} diff --git a/experiment/config/dataset_sampling/full_bb.yaml b/experiment/config/dataset_sampling/full_bb.yaml deleted file mode 100644 index 6b0b8871..00000000 --- a/experiment/config/dataset_sampling/full_bb.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# @package _global_ -dataset_sampling: - name: full_bb - # varying factors (if applicable for pairs) -- sample in range: [min, max] - k: [0, -1] - k_radius: [0, -1] - # varying factors (if applicable for triplets) -- sample in range: [min, max] - n_k: [0, -1] - n_k_mode: 'bounded_below' - n_k_radius: [0, -1] - n_k_radius_mode: 'bounded_below' - # swap incorrect samples - swap_metric: NULL - # swap positive and negative if possible - swap_chance: NULL diff --git a/experiment/config/dataset_sampling/full_ran_l1.yaml b/experiment/config/dataset_sampling/full_ran_l1.yaml deleted file mode 100644 index c5cf0bfe..00000000 --- a/experiment/config/dataset_sampling/full_ran_l1.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# @package _global_ -dataset_sampling: - name: full_ran_l1 - # varying factors (if applicable for pairs) -- sample in range: [min, max] - k: [0, -1] - k_radius: [0, -1] - # varying factors (if applicable for triplets) -- sample in range: [min, max] - n_k: [0, -1] - n_k_mode: 'random' - n_k_radius: [0, -1] - n_k_radius_mode: 'random' - # swap incorrect samples - swap_metric: 'manhattan' - # swap positive and negative if possible - swap_chance: NULL diff --git a/experiment/config/dataset_sampling/full_ran_l2.yaml b/experiment/config/dataset_sampling/full_ran_l2.yaml deleted file mode 100644 index 43c5aef1..00000000 --- a/experiment/config/dataset_sampling/full_ran_l2.yaml +++ /dev/null @@ -1,15 +0,0 @@ -# @package _global_ -dataset_sampling: - name: full_ran_l2 - # varying factors (if applicable for pairs) -- sample in range: [min, max] - k: [0, -1] - k_radius: [0, -1] - # varying factors (if applicable for triplets) -- sample in range: [min, max] - n_k: [0, -1] - n_k_mode: 'random' - n_k_radius: [0, -1] - n_k_radius_mode: 'random' - # swap incorrect samples - swap_metric: 'euclidean' - # swap positive and negative if possible - swap_chance: NULL diff 
--git a/experiment/config/dataset_sampling/gt_dist_combined.yaml b/experiment/config/dataset_sampling/gt_dist_combined.yaml deleted file mode 100644 index 2961d8f2..00000000 --- a/experiment/config/dataset_sampling/gt_dist_combined.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_combined - triplet_sample_mode: "combined" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/dataset_sampling/gt_dist_combined_scaled.yaml b/experiment/config/dataset_sampling/gt_dist_combined_scaled.yaml deleted file mode 100644 index 326031da..00000000 --- a/experiment/config/dataset_sampling/gt_dist_combined_scaled.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_combined_scaled - triplet_sample_mode: "combined_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/dataset_sampling/gt_dist_factors.yaml b/experiment/config/dataset_sampling/gt_dist_factors.yaml deleted file mode 100644 index c2821f41..00000000 --- a/experiment/config/dataset_sampling/gt_dist_factors.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_factors - triplet_sample_mode: "factors" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/dataset_sampling/gt_dist_manhat.yaml b/experiment/config/dataset_sampling/gt_dist_manhat.yaml deleted file mode 100644 index ae7cc831..00000000 --- a/experiment/config/dataset_sampling/gt_dist_manhat.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_manhat - triplet_sample_mode: "manhattan" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/dataset_sampling/gt_dist_manhat_scaled.yaml b/experiment/config/dataset_sampling/gt_dist_manhat_scaled.yaml deleted file mode 100644 index a2e7f8c8..00000000 --- a/experiment/config/dataset_sampling/gt_dist_manhat_scaled.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_manhat_scaled - triplet_sample_mode: "manhattan_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/dataset_sampling/gt_dist_random.yaml b/experiment/config/dataset_sampling/gt_dist_random.yaml deleted file mode 100644 index 64521891..00000000 --- a/experiment/config/dataset_sampling/gt_dist_random.yaml +++ /dev/null @@ -1,5 +0,0 @@ -# @package _global_ -dataset_sampling: - name: dist_random - triplet_sample_mode: "random" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled - triplet_swap_chance: 0 diff --git a/experiment/config/framework/_input_mode_/pair.yaml b/experiment/config/framework/_input_mode_/pair.yaml new file mode 100644 index 00000000..c0596449 --- /dev/null +++ b/experiment/config/framework/_input_mode_/pair.yaml @@ -0,0 +1,3 @@ +# controlled by the framework's defaults list +name: pair +num: 2 diff --git a/experiment/config/framework/_input_mode_/single.yaml b/experiment/config/framework/_input_mode_/single.yaml new file mode 100644 index 00000000..4ac6b0a7 --- /dev/null +++ b/experiment/config/framework/_input_mode_/single.yaml @@ -0,0 +1,3 @@ +# controlled by the framework's defaults list +name: single +num: 1 diff --git a/experiment/config/framework/_input_mode_/triple.yaml 
b/experiment/config/framework/_input_mode_/triple.yaml new file mode 100644 index 00000000..40f44980 --- /dev/null +++ b/experiment/config/framework/_input_mode_/triple.yaml @@ -0,0 +1,3 @@ +# controlled by the framework's defaults list +name: triple +num: 3 diff --git a/experiment/config/framework/_input_mode_/weak_pair.yaml b/experiment/config/framework/_input_mode_/weak_pair.yaml new file mode 100644 index 00000000..cd3a0a26 --- /dev/null +++ b/experiment/config/framework/_input_mode_/weak_pair.yaml @@ -0,0 +1,3 @@ +# controlled by the framework's defaults list +name: weak_pair +num: 2 diff --git a/experiment/config/framework/adagvae_minimal_os.yaml b/experiment/config/framework/adagvae_minimal_os.yaml new file mode 100644 index 00000000..e7bcbf13 --- /dev/null +++ b/experiment/config/framework/adagvae_minimal_os.yaml @@ -0,0 +1,23 @@ +defaults: + - _input_mode_: weak_pair + +name: adagvae_minimal_os + +cfg: + _target_: disent.frameworks.vae.AdaGVaeMinimal.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} + # base vae + latent_distribution: ${settings.framework_opt.latent_distribution} + # disable various components + disable_decoder: FALSE + disable_reg_loss: FALSE + disable_rec_loss: FALSE + disable_aug_loss: FALSE + disable_posterior_scale: NULL + # Beta-VAE + beta: ${settings.framework.beta} + +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/adavae.yaml b/experiment/config/framework/adavae.yaml index 5e62ef61..d234dcbf 100644 --- a/experiment/config/framework/adavae.yaml +++ b/experiment/config/framework/adavae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: pair + name: adavae -module: - _target_: disent.frameworks.vae.AdaVae + +cfg: + _target_: disent.frameworks.vae.AdaVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,12 +17,11 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} # adavae ada_average_mode: gvae # gvae or ml-vae ada_thresh_mode: symmetric_kl ada_thresh_ratio: 0.5 -# settings used elsewhere -data_sample_mode: pair -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/adavae_os.yaml b/experiment/config/framework/adavae_os.yaml index 2588c36f..d8054386 100644 --- a/experiment/config/framework/adavae_os.yaml +++ b/experiment/config/framework/adavae_os.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: weak_pair + name: adavae_os -module: - _target_: disent.frameworks.vae.AdaVae + +cfg: + _target_: disent.frameworks.vae.AdaVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,12 +17,11 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} # adavae ada_average_mode: gvae # gvae or ml-vae ada_thresh_mode: symmetric_kl ada_thresh_ratio: 0.5 -# settings 
used elsewhere -data_sample_mode: weak_pair -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/ae.yaml b/experiment/config/framework/ae.yaml index a21a9e84..cd410b6b 100644 --- a/experiment/config/framework/ae.yaml +++ b/experiment/config/framework/ae.yaml @@ -1,12 +1,17 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: ae -module: - _target_: disent.frameworks.ae.Ae + +cfg: + _target_: disent.frameworks.ae.Ae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # disable various components disable_decoder: FALSE disable_rec_loss: FALSE disable_aug_loss: FALSE -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 1 +meta: + model_z_multiplier: 1 diff --git a/experiment/config/framework/betatcvae.yaml b/experiment/config/framework/betatcvae.yaml index 138dd9f3..32005237 100644 --- a/experiment/config/framework/betatcvae.yaml +++ b/experiment/config/framework/betatcvae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: betatcvae -module: - _target_: disent.frameworks.vae.BetaTcVae + +cfg: + _target_: disent.frameworks.vae.BetaTcVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,8 +17,7 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-TcVae - beta: ${framework.beta} + beta: ${settings.framework.beta} -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/betavae.yaml b/experiment/config/framework/betavae.yaml index 5e4f9c8f..eb0f1540 100644 --- a/experiment/config/framework/betavae.yaml +++ b/experiment/config/framework/betavae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: betavae -module: - _target_: disent.frameworks.vae.BetaVae + +cfg: + _target_: disent.frameworks.vae.BetaVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,8 +17,7 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/dfcvae.yaml b/experiment/config/framework/dfcvae.yaml index a12f913b..9a242f1d 100644 --- a/experiment/config/framework/dfcvae.yaml +++ b/experiment/config/framework/dfcvae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: dfcvae -module: - _target_: disent.frameworks.vae.DfcVae + +cfg: + _target_: disent.frameworks.vae.DfcVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: 
FALSE disable_reg_loss: FALSE @@ -11,11 +17,10 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} # dfcvae feature_layers: ['14', '24', '34', '43'] feature_inputs_mode: 'none' # none, clamp, assert -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/dipvae.yaml b/experiment/config/framework/dipvae.yaml index 678cb192..4efebcf4 100644 --- a/experiment/config/framework/dipvae.yaml +++ b/experiment/config/framework/dipvae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: dipvae -module: - _target_: disent.frameworks.vae.DipVae + +cfg: + _target_: disent.frameworks.vae.DipVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,13 +17,12 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} # DIP-VAE dip_mode: 'ii' # "i" or "ii" dip_beta: 1.0 lambda_d: 1.0 # diagonal weight lambda_od: 0.5 # off diagonal weight -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/infovae.yaml b/experiment/config/framework/infovae.yaml index 89e4ec99..e9b8234b 100644 --- a/experiment/config/framework/infovae.yaml +++ b/experiment/config/framework/infovae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: infovae -module: - _target_: disent.frameworks.vae.InfoVae + +cfg: + _target_: disent.frameworks.vae.InfoVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -17,6 +23,5 @@ module: info_lambda: 5.0 info_kernel: "rbf" # rbf kernel is the only kernel currently -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/tae.yaml b/experiment/config/framework/tae.yaml index 4c627407..39fffd61 100644 --- a/experiment/config/framework/tae.yaml +++ b/experiment/config/framework/tae.yaml @@ -1,7 +1,13 @@ -# @package _group_ +defaults: + - _input_mode_: triple + name: tae -module: - _target_: disent.frameworks.ae.TripletAe + +cfg: + _target_: disent.frameworks.ae.TripletAe.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # disable various components disable_decoder: FALSE disable_rec_loss: FALSE @@ -13,6 +19,5 @@ module: triplet_scale: 0.1 triplet_p: 1 -# settings used elsewhere -data_sample_mode: triplet -model_z_multiplier: 1 +meta: + model_z_multiplier: 1 diff --git a/experiment/config/framework/tvae.yaml b/experiment/config/framework/tvae.yaml index 6dfc7a9e..601c52d8 100644 --- a/experiment/config/framework/tvae.yaml +++ b/experiment/config/framework/tvae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: triple + name: tvae -module: -
_target_: disent.frameworks.vae.TripletVae + +cfg: + _target_: disent.frameworks.vae.TripletVae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,7 +17,7 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL # Beta-VAE - beta: ${framework.beta} + beta: ${settings.framework.beta} # tvae: triplet stuffs triplet_loss: triplet triplet_margin_min: 0.001 @@ -19,6 +25,5 @@ module: triplet_scale: 0.1 triplet_p: 1 -# settings used elsewhere -data_sample_mode: triplet -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/framework/vae.yaml b/experiment/config/framework/vae.yaml index 535869ae..31ec83c4 100644 --- a/experiment/config/framework/vae.yaml +++ b/experiment/config/framework/vae.yaml @@ -1,9 +1,15 @@ -# @package _group_ +defaults: + - _input_mode_: single + name: vae -module: - _target_: disent.frameworks.vae.Vae + +cfg: + _target_: disent.frameworks.vae.Vae.cfg + # base ae + recon_loss: ${settings.framework.recon_loss} + loss_reduction: ${settings.framework.loss_reduction} # base vae - latent_distribution: ${framework.optional.latent_distribution} + latent_distribution: ${settings.framework_opt.latent_distribution} # disable various components disable_decoder: FALSE disable_reg_loss: FALSE @@ -11,6 +17,5 @@ module: disable_aug_loss: FALSE disable_posterior_scale: NULL -# settings used elsewhere -data_sample_mode: single -model_z_multiplier: 2 +meta: + model_z_multiplier: 2 diff --git a/experiment/config/metrics/all.yaml b/experiment/config/metrics/all.yaml index f260880a..cc554da5 100644 --- a/experiment/config/metrics/all.yaml +++ b/experiment/config/metrics/all.yaml @@ -1,14 +1,16 @@ -# @package _group_ metric_list: - - mig: - - sap: - - unsupervised: + - mig: {} + - sap: {} - dci: - every_n_steps: 3600 + every_n_steps: 7200 + on_final: TRUE - factor_vae: - every_n_steps: 3600 + every_n_steps: 7200 + on_final: TRUE + - unsupervised: {} # these are the default settings, these can be placed in the list above default_on_final: TRUE default_on_train: TRUE -default_every_n_steps: 1200 +default_every_n_steps: 2400 +default_begin_first_step: FALSE diff --git a/experiment/config/metrics/common.yaml b/experiment/config/metrics/common.yaml deleted file mode 100644 index c94ff3ef..00000000 --- a/experiment/config/metrics/common.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# @package _group_ -metric_list: - - mig: - - dci: - every_n_steps: 3600 - -# these are the default settings, these can be placed in the list above -default_on_final: TRUE -default_on_train: TRUE -default_every_n_steps: 1200 diff --git a/experiment/config/metrics/fast.yaml b/experiment/config/metrics/fast.yaml index 79e205cf..71853977 100644 --- a/experiment/config/metrics/fast.yaml +++ b/experiment/config/metrics/fast.yaml @@ -1,8 +1,10 @@ -# @package _group_ metric_list: - - mig: + - mig: {} + - sap: {} + - unsupervised: {} # these are the default settings, these can be placed in the list above default_on_final: TRUE default_on_train: TRUE -default_every_n_steps: 600 +default_every_n_steps: 2400 +default_begin_first_step: FALSE diff --git a/experiment/config/metrics/none.yaml b/experiment/config/metrics/none.yaml index 00dbe83c..25b175ea 100644 --- a/experiment/config/metrics/none.yaml +++ 
b/experiment/config/metrics/none.yaml @@ -1,7 +1,7 @@ -# @package _group_ metric_list: [] # these are the default settings, these can be placed in the list above default_on_final: TRUE default_on_train: TRUE default_every_n_steps: 1200 +default_begin_first_step: FALSE diff --git a/experiment/config/metrics/standard.yaml b/experiment/config/metrics/standard.yaml new file mode 100644 index 00000000..49ba02de --- /dev/null +++ b/experiment/config/metrics/standard.yaml @@ -0,0 +1,15 @@ +metric_list: + - mig: {} + - sap: {} + - dci: + every_n_steps: 7200 + on_final: TRUE + - factor_vae: + every_n_steps: 7200 + on_final: TRUE + +# these are the default settings, these can be placed in the list above +default_on_final: TRUE +default_on_train: TRUE +default_every_n_steps: 2400 +default_begin_first_step: FALSE diff --git a/experiment/config/metrics/test.yaml b/experiment/config/metrics/test.yaml index 3482b65c..698ed061 100644 --- a/experiment/config/metrics/test.yaml +++ b/experiment/config/metrics/test.yaml @@ -1,4 +1,3 @@ -# @package _group_ metric_list: - mig: every_n_steps: 112 @@ -15,3 +14,4 @@ metric_list: default_on_final: FALSE default_on_train: TRUE default_every_n_steps: 200 +default_begin_first_step: FALSE diff --git a/experiment/config/model/linear.yaml b/experiment/config/model/linear.yaml new file mode 100644 index 00000000..30e1ed1e --- /dev/null +++ b/experiment/config/model/linear.yaml @@ -0,0 +1,12 @@ +name: linear + +encoder_cls: + _target_: disent.model.ae.EncoderLinear + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + +decoder_cls: + _target_: disent.model.ae.DecoderLinear + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/model/norm_conv64.yaml b/experiment/config/model/norm_conv64.yaml index 6b4b6995..3068055b 100644 --- a/experiment/config/model/norm_conv64.yaml +++ b/experiment/config/model/norm_conv64.yaml @@ -1,18 +1,18 @@ -# @package _group_ name: norm_conv64 -weight_init: 'xavier_normal' -encoder: + +encoder_cls: _target_: disent.model.ae.EncoderConv64Norm - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} - z_multiplier: ${framework.model_z_multiplier} + x_shape: ${data.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} activation: ${model.activation} norm: ${model.norm} norm_pre_act: ${model.norm_pre_act} -decoder: - _target_: disent.model.ae.DecoderConv64Alt - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} + +decoder_cls: + _target_: disent.model.ae.DecoderConv64Norm + x_shape: ${data.meta.x_shape} + z_size: ${settings.model.z_size} activation: ${model.activation} norm: ${model.norm} norm_pre_act: ${model.norm_pre_act} diff --git a/experiment/config/model/test.yaml b/experiment/config/model/test.yaml deleted file mode 100644 index c84c08f3..00000000 --- a/experiment/config/model/test.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# @package _group_ -name: test -weight_init: 'xavier_normal' -encoder: - _target_: disent.model.ae.EncoderTest - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} - z_multiplier: ${framework.model_z_multiplier} -decoder: - _target_: disent.model.ae.DecoderTest - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} diff --git a/experiment/config/model/vae_conv64.yaml b/experiment/config/model/vae_conv64.yaml index 1b7d0f43..a05d00c0 100644 --- a/experiment/config/model/vae_conv64.yaml +++ b/experiment/config/model/vae_conv64.yaml @@ -1,12 +1,12 @@ -# @package 
_group_ name: vae_conv64 -weight_init: 'xavier_normal' -encoder: + +encoder_cls: _target_: disent.model.ae.EncoderConv64 - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} - z_multiplier: ${framework.model_z_multiplier} -decoder: + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + +decoder_cls: _target_: disent.model.ae.DecoderConv64 - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/model/vae_fc.yaml b/experiment/config/model/vae_fc.yaml index ce0f6152..6a68f40d 100644 --- a/experiment/config/model/vae_fc.yaml +++ b/experiment/config/model/vae_fc.yaml @@ -1,12 +1,12 @@ -# @package _group_ name: vae_fc -weight_init: 'xavier_normal' -encoder: + +encoder_cls: _target_: disent.model.ae.EncoderFC - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} - z_multiplier: ${framework.model_z_multiplier} -decoder: + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} + z_multiplier: ${framework.meta.model_z_multiplier} + +decoder_cls: _target_: disent.model.ae.DecoderFC - x_shape: ${dataset.x_shape} - z_size: ${model.z_size} + x_shape: ${dataset.meta.x_shape} + z_size: ${settings.model.z_size} diff --git a/experiment/config/optimizer/adabelief.yaml b/experiment/config/optimizer/adabelief.yaml index db3d3b73..50373136 100644 --- a/experiment/config/optimizer/adabelief.yaml +++ b/experiment/config/optimizer/adabelief.yaml @@ -1,12 +1,15 @@ -# @package framework.module -optimizer: torch_optimizer.AdaBelief -optimizer_kwargs: - lr: ${optimizer.lr} - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0 +# @package _global_ - amsgrad: False - weight_decouple: False - fixed_decay: False - rectify: False +framework: + cfg: + optimizer: torch_optimizer.AdaBelief + optimizer_kwargs: + lr: ${settings.optimizer.lr} + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0 + + amsgrad: False + weight_decouple: False + fixed_decay: False + rectify: False diff --git a/experiment/config/optimizer/adam.yaml b/experiment/config/optimizer/adam.yaml index 686a12c4..1fe4cf18 100644 --- a/experiment/config/optimizer/adam.yaml +++ b/experiment/config/optimizer/adam.yaml @@ -1,9 +1,12 @@ -# @package framework.module -optimizer: torch.optim.Adam -optimizer_kwargs: - lr: ${optimizer.lr} - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0 +# @package _global_ - amsgrad: False +framework: + cfg: + optimizer: torch.optim.Adam + optimizer_kwargs: + lr: ${settings.optimizer.lr} + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0 + + amsgrad: False diff --git a/experiment/config/optimizer/amsgrad.yaml b/experiment/config/optimizer/amsgrad.yaml index ead824ca..8e3e5181 100644 --- a/experiment/config/optimizer/amsgrad.yaml +++ b/experiment/config/optimizer/amsgrad.yaml @@ -1,9 +1,12 @@ -# @package framework.module -optimizer: torch.optim.Adam -optimizer_kwargs: - lr: ${optimizer.lr} - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0 +# @package _global_ - amsgrad: True +framework: + cfg: + optimizer: torch.optim.Adam + optimizer_kwargs: + lr: ${settings.optimizer.lr} + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0 + + amsgrad: True diff --git a/experiment/config/optimizer/radam.yaml b/experiment/config/optimizer/radam.yaml index 6ecfa5e8..8fc7635f 100644 --- a/experiment/config/optimizer/radam.yaml +++ b/experiment/config/optimizer/radam.yaml @@ -1,7 +1,10 @@ -# @package framework.module -optimizer: torch_optimizer.RAdam -optimizer_kwargs: 
- lr: ${optimizer.lr} - betas: [0.9, 0.999] - eps: 1e-8 - weight_decay: 0 +# @package _global_ + +framework: + cfg: + optimizer: torch_optimizer.RAdam + optimizer_kwargs: + lr: ${settings.optimizer.lr} + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 0 diff --git a/experiment/config/optimizer/rmsprop.yaml b/experiment/config/optimizer/rmsprop.yaml index 42e876c9..9b150407 100644 --- a/experiment/config/optimizer/rmsprop.yaml +++ b/experiment/config/optimizer/rmsprop.yaml @@ -1,10 +1,13 @@ -# @package framework.module -optimizer: torch.optim.RMSprop -optimizer_kwargs: - lr: ${optimizer.lr} # default was 1e-2 - alpha: 0.99 - eps: 1e-8 - weight_decay: 0 +# @package _global_ - momentum: 0 - centered: False +framework: + cfg: + optimizer: torch.optim.RMSprop + optimizer_kwargs: + lr: ${settings.optimizer.lr} # default was 1e-2 + alpha: 0.99 + eps: 1e-8 + weight_decay: 0 + + momentum: 0 + centered: False diff --git a/experiment/config/optimizer/sgd.yaml b/experiment/config/optimizer/sgd.yaml index 1dfe53da..07877232 100644 --- a/experiment/config/optimizer/sgd.yaml +++ b/experiment/config/optimizer/sgd.yaml @@ -1,8 +1,11 @@ -# @package framework.module -optimizer: torch.optim.SGD -optimizer_kwargs: - lr: ${optimizer.lr} - momentum: 0 - dampening: 0 - weight_decay: 0 - nesterov: False +# @package _global_ + +framework: + cfg: + optimizer: torch.optim.SGD + optimizer_kwargs: + lr: ${settings.optimizer.lr} + momentum: 0 + dampening: 0 + weight_decay: 0 + nesterov: False diff --git a/experiment/config/run_action/prepare_data.yaml b/experiment/config/run_action/prepare_data.yaml new file mode 100644 index 00000000..455efa6b --- /dev/null +++ b/experiment/config/run_action/prepare_data.yaml @@ -0,0 +1,8 @@ +# @package _global_ +action: prepare_data + +# override settings from job/location +dsettings: + dataset: + try_in_memory: FALSE + prepare: TRUE diff --git a/experiment/config/run_action/train.yaml b/experiment/config/run_action/train.yaml new file mode 100644 index 00000000..0f1a4ebe --- /dev/null +++ b/experiment/config/run_action/train.yaml @@ -0,0 +1,2 @@ +# @package _global_ +action: train diff --git a/experiment/config/run_callbacks/all.yaml b/experiment/config/run_callbacks/all.yaml index d273971a..fd4c88a6 100644 --- a/experiment/config/run_callbacks/all.yaml +++ b/experiment/config/run_callbacks/all.yaml @@ -1,9 +1,42 @@ # @package _global_ + callbacks: latent_cycle: seed: 7777 - every_n_steps: 1200 + every_n_steps: 3600 mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' - correlation: - repeats_per_factor: 10 - every_n_steps: 7200 + begin_first_step: TRUE + + gt_dists: + seed: 7777 + every_n_steps: 3600 + traversal_repeats: 100 + begin_first_step: TRUE + +# correlation: +# repeats_per_factor: 10 +# every_n_steps: 7200 + +# latent_cycle: +# _target_: disent.util.lightning.callbacks.VaeLatentCycleLoggingCallback +# seed: 7777 +# every_n_steps: 7200 +# begin_first_step: FALSE +# mode: 'minmax_interval_cycle' # minmax_interval_cycle, fitted_gaussian_cycle +# recon_min: ${dataset.vis_min} +# recon_max: ${dataset.vis_max} +# recon_mean: ${dataset.vis_mean} +# recon_std: ${dataset.vis_std} +# +# gt_dists: +# _target_: disent.util.lightning.callbacks.VaeGtDistsLoggingCallback +# seed: 7777 +# every_n_steps: 7200 +# begin_first_step: TRUE +# traversal_repeats: 100 +# plt_block_size: 1.25 +# plt_show: False +# plt_transpose: False +# log_wandb: TRUE +# batch_size: ${cfg.dataset.batch_size} +# include_factor_dists: TRUE diff --git 
a/experiment/config/run_callbacks/none.yaml b/experiment/config/run_callbacks/none.yaml index 61fb387d..5ceab7b5 100644 --- a/experiment/config/run_callbacks/none.yaml +++ b/experiment/config/run_callbacks/none.yaml @@ -1,2 +1,4 @@ # @package _global_ + callbacks: + # empty! diff --git a/experiment/config/run_callbacks/test.yaml b/experiment/config/run_callbacks/test.yaml index 16587eb6..255cda8a 100644 --- a/experiment/config/run_callbacks/test.yaml +++ b/experiment/config/run_callbacks/test.yaml @@ -1,9 +1,18 @@ # @package _global_ + callbacks: latent_cycle: seed: 7777 - every_n_steps: 100 + every_n_steps: 3 mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' - correlation: - repeats_per_factor: 10 - every_n_steps: 102 + begin_first_step: FALSE + + gt_dists: + seed: 7777 + every_n_steps: 4 + traversal_repeats: 3 + begin_first_step: FALSE + +# correlation: +# repeats_per_factor: 3 +# every_n_steps: 5 diff --git a/experiment/config/run_callbacks/vis.yaml b/experiment/config/run_callbacks/vis.yaml index 860a90fd..2087779c 100644 --- a/experiment/config/run_callbacks/vis.yaml +++ b/experiment/config/run_callbacks/vis.yaml @@ -1,6 +1,14 @@ # @package _global_ + callbacks: latent_cycle: seed: 7777 - every_n_steps: 600 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' \ No newline at end of file + every_n_steps: 3600 + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + begin_first_step: TRUE + + gt_dists: + seed: 7777 + every_n_steps: 3600 + traversal_repeats: 100 + begin_first_step: TRUE diff --git a/experiment/config/run_callbacks/vis_fast.yaml b/experiment/config/run_callbacks/vis_fast.yaml new file mode 100644 index 00000000..3df24a15 --- /dev/null +++ b/experiment/config/run_callbacks/vis_fast.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +callbacks: + latent_cycle: + seed: 7777 + every_n_steps: 1800 + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + begin_first_step: TRUE + + gt_dists: + seed: 7777 + every_n_steps: 1800 + traversal_repeats: 100 + begin_first_step: TRUE diff --git a/experiment/config/run_callbacks/vis_slow.yaml b/experiment/config/run_callbacks/vis_slow.yaml index 75875cdc..83f0516b 100644 --- a/experiment/config/run_callbacks/vis_slow.yaml +++ b/experiment/config/run_callbacks/vis_slow.yaml @@ -1,6 +1,14 @@ # @package _global_ + callbacks: latent_cycle: seed: 7777 - every_n_steps: 3600 - mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' \ No newline at end of file + every_n_steps: 7200 + mode: 'minmax_interval_cycle' # 'minmax_interval_cycle', 'fitted_gaussian_cycle' + begin_first_step: TRUE + + gt_dists: + seed: 7777 + every_n_steps: 7200 + traversal_repeats: 100 + begin_first_step: TRUE diff --git a/experiment/config/run_launcher/local.yaml b/experiment/config/run_launcher/local.yaml new file mode 100644 index 00000000..7e8dba86 --- /dev/null +++ b/experiment/config/run_launcher/local.yaml @@ -0,0 +1,4 @@ +# @package _global_ + +defaults: + - override /hydra/launcher: basic diff --git a/experiment/config/run_launcher/slurm.yaml b/experiment/config/run_launcher/slurm.yaml new file mode 100644 index 00000000..f2d86558 --- /dev/null +++ b/experiment/config/run_launcher/slurm.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +defaults: + - override /hydra/launcher: submitit_slurm + +hydra: + launcher: + partition: ${dsettings.launcher.partition} + mem_gb: 0 + timeout_min: 1440 # minutes + submitit_folder: 
'${hydra.sweep.dir}/%j' + array_parallelism: ${dsettings.launcher.array_parallelism} + exclude: ${dsettings.launcher.exclude} diff --git a/experiment/config/run_length/debug.yaml b/experiment/config/run_length/debug.yaml index a0433d3c..d3f4866c 100644 --- a/experiment/config/run_length/debug.yaml +++ b/experiment/config/run_length/debug.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 10 - steps: 10 + max_epochs: 3 + max_steps: 3 diff --git a/experiment/config/run_length/epic.yaml b/experiment/config/run_length/epic.yaml index f9b7b273..9a5e22bd 100644 --- a/experiment/config/run_length/epic.yaml +++ b/experiment/config/run_length/epic.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 230400 - steps: 230400 + max_epochs: 230400 + max_steps: 230400 diff --git a/experiment/config/run_length/long.yaml b/experiment/config/run_length/long.yaml index 9c5b7a2e..6d03ec4c 100644 --- a/experiment/config/run_length/long.yaml +++ b/experiment/config/run_length/long.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 115200 - steps: 115200 + max_epochs: 115200 + max_steps: 115200 diff --git a/experiment/config/run_length/medium.yaml b/experiment/config/run_length/medium.yaml index 8ba23e7f..92dbd90e 100644 --- a/experiment/config/run_length/medium.yaml +++ b/experiment/config/run_length/medium.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 57600 - steps: 57600 + max_epochs: 57600 + max_steps: 57600 diff --git a/experiment/config/run_length/short.yaml b/experiment/config/run_length/short.yaml index 44d76fd6..b5588617 100644 --- a/experiment/config/run_length/short.yaml +++ b/experiment/config/run_length/short.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 28800 - steps: 28800 + max_epochs: 28800 + max_steps: 28800 diff --git a/experiment/config/run_length/test.yaml b/experiment/config/run_length/test.yaml index 146d0153..ce1d5870 100644 --- a/experiment/config/run_length/test.yaml +++ b/experiment/config/run_length/test.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 1 - steps: 1 + max_epochs: 5 + max_steps: 5 diff --git a/experiment/config/run_length/tiny.yaml b/experiment/config/run_length/tiny.yaml index aeea555d..d4f30284 100644 --- a/experiment/config/run_length/tiny.yaml +++ b/experiment/config/run_length/tiny.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 14400 - steps: 14400 + max_epochs: 14400 + max_steps: 14400 diff --git a/experiment/config/run_length/vtiny.yaml b/experiment/config/run_length/vtiny.yaml deleted file mode 100644 index adf34e33..00000000 --- a/experiment/config/run_length/vtiny.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# @package _global_ -trainer: - epochs: 7200 - steps: 7200 diff --git a/experiment/config/run_length/xtiny.yaml b/experiment/config/run_length/xtiny.yaml index e9128cc4..0da9f874 100644 --- a/experiment/config/run_length/xtiny.yaml +++ b/experiment/config/run_length/xtiny.yaml @@ -1,4 +1,5 @@ # @package _global_ + trainer: - epochs: 3600 - steps: 3600 + max_epochs: 7200 + max_steps: 7200 diff --git a/experiment/config/run_location/cluster.yaml b/experiment/config/run_location/cluster.yaml index a47c7bac..9194493f 100644 --- a/experiment/config/run_location/cluster.yaml +++ b/experiment/config/run_location/cluster.yaml @@ -1,30 +1,33 @@ # @package _global_ -logging: - logs_dir: 'logs' + +dsettings: + trainer: + cuda: NULL # auto-detect cuda, some nodes may be configured incorrectly + storage: + logs_dir: 'logs' + data_root: '/tmp/${oc.env:USER}/datasets' + dataset: + 
gpu_augment: FALSE + prepare: TRUE + try_in_memory: TRUE + launcher: + partition: batch + array_parallelism: 16 + exclude: "mscluster93,mscluster94,mscluster97,mscluster99" trainer: - cuda: NULL # auto detect cuda, some nodes are not configured correctly prepare_data_per_node: TRUE -dataset: +dataloader: num_workers: 8 - batch_size: 256 - data_root: '/tmp/${env:USER}/datasets' - pin_memory: ${trainer.cuda} - try_in_memory: FALSE - gpu_augment: FALSE + pin_memory: ${dsettings.trainer.cuda} + batch_size: ${settings.dataset.batch_size} hydra: job: name: 'disent' run: - dir: '${logging.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' sweep: - dir: '${logging.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' subdir: '${hydra.job.id}' # hydra.job.id is not available for dir - launcher: - partition: ${job.partition} - mem_gb: 0 - timeout_min: 1440 # minutes - submitit_folder: '${hydra.sweep.dir}/%j' - array_parallelism: 10 diff --git a/experiment/config/run_location/cluster_many.yaml b/experiment/config/run_location/cluster_many.yaml deleted file mode 100644 index 005c6945..00000000 --- a/experiment/config/run_location/cluster_many.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# @package _global_ -logging: - logs_dir: 'logs' - -trainer: - cuda: NULL # auto detect cuda, some nodes are not configured correctly - prepare_data_per_node: TRUE - -dataset: - num_workers: 8 - batch_size: 256 - data_root: '/tmp/${env:USER}/datasets' - pin_memory: ${trainer.cuda} - try_in_memory: FALSE - gpu_augment: FALSE - -hydra: - job: - name: 'disent' - run: - dir: '${logging.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - sweep: - dir: '${logging.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - subdir: '${hydra.job.id}' # hydra.job.id is not available for dir - launcher: - partition: ${job.partition} - mem_gb: 0 - timeout_min: 1440 # minutes - submitit_folder: '${hydra.sweep.dir}/%j' - array_parallelism: 32 diff --git a/experiment/config/run_location/local.yaml b/experiment/config/run_location/local.yaml index c41a9d7c..458501f9 100644 --- a/experiment/config/run_location/local.yaml +++ b/experiment/config/run_location/local.yaml @@ -1,24 +1,29 @@ # @package _global_ -logging: - logs_dir: 'logs' + +dsettings: + trainer: + cuda: TRUE + storage: + logs_dir: 'logs' + data_root: '/tmp/${oc.env:USER}/datasets' + dataset: + gpu_augment: FALSE + prepare: TRUE + try_in_memory: TRUE trainer: - cuda: TRUE prepare_data_per_node: TRUE -dataset: +dataloader: num_workers: 8 - batch_size: 256 - data_root: '/tmp/${env:USER}/datasets' - pin_memory: ${trainer.cuda} - try_in_memory: FALSE - gpu_augment: FALSE + pin_memory: ${dsettings.trainer.cuda} + batch_size: ${settings.dataset.batch_size} hydra: job: name: 'disent' run: - dir: '${logging.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' sweep: - dir: '${logging.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_location/local_cpu.yaml b/experiment/config/run_location/local_cpu.yaml index 9f82de47..d1bcd61a 100644 --- 
a/experiment/config/run_location/local_cpu.yaml +++ b/experiment/config/run_location/local_cpu.yaml @@ -1,24 +1,29 @@ # @package _global_ -logging: - logs_dir: 'logs' + +dsettings: + trainer: + cuda: FALSE + storage: + logs_dir: 'logs' + data_root: '/tmp/${oc.env:USER}/datasets' + dataset: + gpu_augment: FALSE + prepare: TRUE + try_in_memory: TRUE trainer: - cuda: FALSE prepare_data_per_node: TRUE -dataset: +dataloader: num_workers: 8 - batch_size: 256 - data_root: '/tmp/${env:USER}/datasets' - pin_memory: ${trainer.cuda} - try_in_memory: FALSE - gpu_augment: FALSE + pin_memory: ${dsettings.trainer.cuda} + batch_size: ${settings.dataset.batch_size} hydra: job: name: 'disent' run: - dir: '${logging.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' sweep: - dir: '${logging.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_location/stampede_shr.yaml b/experiment/config/run_location/stampede_shr.yaml new file mode 100644 index 00000000..05ba8626 --- /dev/null +++ b/experiment/config/run_location/stampede_shr.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +dsettings: + trainer: + cuda: NULL # auto-detect cuda, some nodes are not configured correctly + storage: + logs_dir: 'logs' + data_root: '${oc.env:HOME}/downloads/datasets' # WE NEED TO BE VERY CAREFUL ABOUT USING A SHARED DRIVE + dataset: + gpu_augment: FALSE + prepare: FALSE # WE MUST PREPARE DATA MANUALLY BEFOREHAND + try_in_memory: TRUE + launcher: + partition: stampede + array_parallelism: 16 + exclude: "mscluster93,mscluster94,mscluster97,mscluster99" + +trainer: + prepare_data_per_node: TRUE + +dataloader: + num_workers: 16 + pin_memory: ${dsettings.trainer.cuda} + batch_size: ${settings.dataset.batch_size} + +hydra: + job: + name: 'disent' + run: + dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + sweep: + dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_location/stampede_tmp.yaml b/experiment/config/run_location/stampede_tmp.yaml new file mode 100644 index 00000000..d596b335 --- /dev/null +++ b/experiment/config/run_location/stampede_tmp.yaml @@ -0,0 +1,33 @@ +# @package _global_ + +dsettings: + trainer: + cuda: NULL # auto-detect cuda, some nodes are not configured correctly + storage: + logs_dir: 'logs' + data_root: '/tmp/${oc.env:USER}/datasets' + dataset: + gpu_augment: FALSE + prepare: TRUE + try_in_memory: TRUE + launcher: + partition: stampede + array_parallelism: 16 + exclude: "mscluster93,mscluster94,mscluster97,mscluster99" + +trainer: + prepare_data_per_node: TRUE + +dataloader: + num_workers: 16 + pin_memory: ${dsettings.trainer.cuda} + batch_size: ${settings.dataset.batch_size} + +hydra: + job: + name: 'disent' + run: + dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + sweep: + dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' + subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_logging/none.yaml b/experiment/config/run_logging/none.yaml index a2862c6e..2a99d3c4 100644 --- 
a/experiment/config/run_logging/none.yaml +++ b/experiment/config/run_logging/none.yaml @@ -1,12 +1,18 @@ # @package _global_ + +defaults: + - override /hydra/job_logging: colorlog + - override /hydra/hydra_logging: colorlog + +trainer: + log_every_n_steps: 50 + flush_logs_every_n_steps: 100 + progress_bar_refresh_rate: 0 # disable the builtin progress bar + callbacks: progress: interval: 5 logging: - # logging frequency in number of batchers, default: log_every_n_steps=50, flush_logs_every_n_steps=100 - log_every_n_steps: 50 - flush_logs_every_n_steps: 100 - # log to online service wandb: enabled: FALSE diff --git a/experiment/config/run_logging/wandb.yaml b/experiment/config/run_logging/wandb.yaml index 3bd6c1d6..a72cb887 100644 --- a/experiment/config/run_logging/wandb.yaml +++ b/experiment/config/run_logging/wandb.yaml @@ -1,18 +1,24 @@ # @package _global_ + +defaults: + - override /hydra/job_logging: colorlog + - override /hydra/hydra_logging: colorlog + +trainer: + log_every_n_steps: 100 + flush_logs_every_n_steps: 200 + progress_bar_refresh_rate: 0 # disable the builtin progress bar + callbacks: progress: - interval: 30 + interval: 15 logging: - # logging frequency in number of batchers, default: log_every_n_steps=50, flush_logs_every_n_steps=100 - log_every_n_steps: 100 - flush_logs_every_n_steps: 200 - # log to online service wandb: enabled: TRUE offline: FALSE - entity: '${job.user}' - project: '${job.project}' - name: '${job.name}' + entity: '${settings.job.user}' + project: '${settings.job.project}' + name: '${settings.job.name}' group: null tags: [] diff --git a/experiment/config/run_logging/wandb_fast.yaml b/experiment/config/run_logging/wandb_fast.yaml index c9adbcfd..75c305ee 100644 --- a/experiment/config/run_logging/wandb_fast.yaml +++ b/experiment/config/run_logging/wandb_fast.yaml @@ -1,18 +1,24 @@ # @package _global_ + +defaults: + - override /hydra/job_logging: colorlog + - override /hydra/hydra_logging: colorlog + +trainer: + log_every_n_steps: 50 + flush_logs_every_n_steps: 100 + progress_bar_refresh_rate: 0 # disable the builtin progress bar + callbacks: progress: interval: 5 logging: - # logging frequency in number of batchers, default: log_every_n_steps=50, flush_logs_every_n_steps=100 - log_every_n_steps: 50 - flush_logs_every_n_steps: 100 - # log to online service wandb: enabled: TRUE offline: FALSE - entity: '${job.user}' - project: '${job.project}' - name: '${job.name}' + entity: '${settings.job.user}' + project: '${settings.job.project}' + name: '${settings.job.name}' group: null tags: [] diff --git a/experiment/config/run_logging/wandb_fast_offline.yaml b/experiment/config/run_logging/wandb_fast_offline.yaml new file mode 100644 index 00000000..372e480a --- /dev/null +++ b/experiment/config/run_logging/wandb_fast_offline.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +defaults: + - override /hydra/job_logging: colorlog + - override /hydra/hydra_logging: colorlog + +trainer: + log_every_n_steps: 50 + flush_logs_every_n_steps: 100 + progress_bar_refresh_rate: 0 # disable the builtin progress bar + +callbacks: + progress: + interval: 5 + +logging: + wandb: + enabled: TRUE + offline: TRUE + entity: '${settings.job.user}' + project: '${settings.job.project}' + name: '${settings.job.name}' + group: null + tags: [] diff --git a/experiment/config/run_logging/wandb_slow.yaml b/experiment/config/run_logging/wandb_slow.yaml index 7db1c013..f7f4a49c 100644 --- a/experiment/config/run_logging/wandb_slow.yaml +++ b/experiment/config/run_logging/wandb_slow.yaml @@ -1,18 
+1,24 @@ # @package _global_ + +defaults: + - override /hydra/job_logging: colorlog + - override /hydra/hydra_logging: colorlog + +trainer: + log_every_n_steps: 200 + flush_logs_every_n_steps: 400 + progress_bar_refresh_rate: 0 # disable the builtin progress bar + callbacks: progress: interval: 30 logging: - # logging frequency in number of batchers, default: log_every_n_steps=50, flush_logs_every_n_steps=100 - log_every_n_steps: 400 - flush_logs_every_n_steps: 800 - # log to online service wandb: enabled: TRUE offline: FALSE - entity: '${job.user}' - project: '${job.project}' - name: '${job.name}' + entity: '${settings.job.user}' + project: '${settings.job.project}' + name: '${settings.job.name}' group: null tags: [] diff --git a/experiment/config/dataset_sampler/episodes_pair.yaml b/experiment/config/sampling/_sampler_/episodes__pair.yaml similarity index 78% rename from experiment/config/dataset_sampler/episodes_pair.yaml rename to experiment/config/sampling/_sampler_/episodes__pair.yaml index c5cb76c7..a0d78b12 100644 --- a/experiment/config/dataset_sampler/episodes_pair.yaml +++ b/experiment/config/sampling/_sampler_/episodes__pair.yaml @@ -1,6 +1,6 @@ -# @package _group_ -name: episode_pair -sampler: +name: episodes__pair + +sampler_cls: _target_: disent.dataset.sampling.RandomEpisodeSampler num_samples: 2 # TODO: this needs to be updated to use the same API as ground_truth wrappers. diff --git a/experiment/config/dataset_sampler/episodes_single.yaml b/experiment/config/sampling/_sampler_/episodes__single.yaml similarity index 78% rename from experiment/config/dataset_sampler/episodes_single.yaml rename to experiment/config/sampling/_sampler_/episodes__single.yaml index 42829f37..d71d5b50 100644 --- a/experiment/config/dataset_sampler/episodes_single.yaml +++ b/experiment/config/sampling/_sampler_/episodes__single.yaml @@ -1,6 +1,6 @@ -# @package _group_ -name: episode_single -sampler: +name: episodes__single + +sampler_cls: _target_: disent.dataset.sampling.RandomEpisodeSampler num_samples: 1 # TODO: this needs to be updated to use the same API as ground_truth wrappers. diff --git a/experiment/config/dataset_sampler/episodes_triplet.yaml b/experiment/config/sampling/_sampler_/episodes__triplet.yaml similarity index 77% rename from experiment/config/dataset_sampler/episodes_triplet.yaml rename to experiment/config/sampling/_sampler_/episodes__triplet.yaml index b79f3457..ff16a907 100644 --- a/experiment/config/dataset_sampler/episodes_triplet.yaml +++ b/experiment/config/sampling/_sampler_/episodes__triplet.yaml @@ -1,6 +1,6 @@ -# @package _group_ -name: episode_triplet -sampler: +name: episodes__triplet + +sampler_cls: _target_: disent.dataset.sampling.RandomEpisodeSampler num_samples: 3 # TODO: this needs to be updated to use the same API as ground_truth wrappers. diff --git a/experiment/config/dataset_sampler/episodes_weak_pair.yaml b/experiment/config/sampling/_sampler_/episodes__weak_pair.yaml similarity index 89% rename from experiment/config/dataset_sampler/episodes_weak_pair.yaml rename to experiment/config/sampling/_sampler_/episodes__weak_pair.yaml index 4d69cda3..4f4ed16f 100644 --- a/experiment/config/dataset_sampler/episodes_weak_pair.yaml +++ b/experiment/config/sampling/_sampler_/episodes__weak_pair.yaml @@ -1,6 +1,6 @@ -# @package _group_ -name: episode_pair -sampler: +name: episodes__weak_pair + +sampler_cls: _target_: disent.dataset.sampling.RandomEpisodeSampler num_samples: 2 # TODO: this needs to be updated to use the same API as ground_truth wrappers. 
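Note on the renamed sampler configs above: the old `# @package _group_` header is dropped and the instantiation target now lives under a `sampler_cls` key. As the `experiment/util/hydra_data.py` hunk further down shows, it is this `sampler_cls` node that gets passed through `hydra.utils.instantiate(...)` and handed to `DisentDataset` as the sampler. A minimal sketch of that resolution follows; it is illustrative only and not part of this patch, assuming the Hydra 1.1 / OmegaConf 2.1 versions pinned in requirements-experiment.txt:

    # sketch: how a sampling/_sampler_/*.yaml entry resolves into a sampler instance
    import hydra.utils
    from omegaconf import OmegaConf

    # equivalent of experiment/config/sampling/_sampler_/gt__single.yaml
    sampler_cfg = OmegaConf.create({
        'name': 'gt__single',
        'sampler_cls': {'_target_': 'disent.dataset.sampling.GroundTruthSingleSampler'},
    })

    # Hydra instantiates the class named by _target_ with the remaining keys as kwargs
    sampler = hydra.utils.instantiate(sampler_cfg.sampler_cls)
    # `sampler` is then what HydraDataModule passes to DisentDataset(...) as its sampler

Samplers whose configs interpolate values such as ${sampling.k} would additionally need those keys present in the composed config before instantiation.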
diff --git a/experiment/config/sampling/_sampler_/gt__pair.yaml b/experiment/config/sampling/_sampler_/gt__pair.yaml new file mode 100644 index 00000000..53a02837 --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt__pair.yaml @@ -0,0 +1,8 @@ +name: gt__pair + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthPairSampler + # factor sampling + p_k_range: ${sampling.k} + # radius sampling + p_radius_range: ${sampling.k_radius} diff --git a/experiment/config/dataset_sampler/ground_truth_single.yaml b/experiment/config/sampling/_sampler_/gt__single.yaml similarity index 53% rename from experiment/config/dataset_sampler/ground_truth_single.yaml rename to experiment/config/sampling/_sampler_/gt__single.yaml index e98d4b48..b3831d69 100644 --- a/experiment/config/dataset_sampler/ground_truth_single.yaml +++ b/experiment/config/sampling/_sampler_/gt__single.yaml @@ -1,4 +1,4 @@ -# @package _group_ -name: ground_truth_single -sampler: +name: gt__single + +sampler_cls: _target_: disent.dataset.sampling.GroundTruthSingleSampler diff --git a/experiment/config/sampling/_sampler_/gt__triplet.yaml b/experiment/config/sampling/_sampler_/gt__triplet.yaml new file mode 100644 index 00000000..09e9635d --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt__triplet.yaml @@ -0,0 +1,16 @@ +name: gt__triplet + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthTripleSampler + # factor sampling + p_k_range: ${sampling.k} + n_k_range: ${sampling.n_k} + n_k_sample_mode: ${sampling.n_k_mode} + n_k_is_shared: TRUE + # radius sampling + p_radius_range: ${sampling.k_radius} + n_radius_range: ${sampling.n_k_radius} + n_radius_sample_mode: ${sampling.n_k_radius_mode} + # final checks + swap_metric: ${sampling.swap_metric} + swap_chance: ${sampling.swap_chance} diff --git a/experiment/config/sampling/_sampler_/gt__weak_pair.yaml b/experiment/config/sampling/_sampler_/gt__weak_pair.yaml new file mode 100644 index 00000000..7561655b --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt__weak_pair.yaml @@ -0,0 +1,6 @@ +name: gt__weak_pair + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthPairOrigSampler + # factor sampling + p_k: ${sampling.k.1} diff --git a/experiment/config/sampling/_sampler_/gt_dist__pair.yaml b/experiment/config/sampling/_sampler_/gt_dist__pair.yaml new file mode 100644 index 00000000..5897a933 --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt_dist__pair.yaml @@ -0,0 +1,7 @@ +name: gt_dist__pair + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthDistSampler + num_samples: 2 + triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined + triplet_swap_chance: ${sampling.triplet_swap_chance} diff --git a/experiment/config/sampling/_sampler_/gt_dist__single.yaml b/experiment/config/sampling/_sampler_/gt_dist__single.yaml new file mode 100644 index 00000000..aeb46f27 --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt_dist__single.yaml @@ -0,0 +1,7 @@ +name: gt_dist__single + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthDistSampler + num_samples: 1 + triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined + triplet_swap_chance: ${sampling.triplet_swap_chance} diff --git a/experiment/config/sampling/_sampler_/gt_dist__triplet.yaml b/experiment/config/sampling/_sampler_/gt_dist__triplet.yaml new file mode 100644 index 00000000..cbcb6488 --- /dev/null +++ b/experiment/config/sampling/_sampler_/gt_dist__triplet.yaml @@ -0,0 +1,7 @@ +name: 
gt_dist__triplet + +sampler_cls: + _target_: disent.dataset.sampling.GroundTruthDistSampler + num_samples: 3 + triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined + triplet_swap_chance: ${sampling.triplet_swap_chance} diff --git a/experiment/config/dataset_sampler/gt_dist_weak_pair.yaml b/experiment/config/sampling/_sampler_/gt_dist__weak_pair.yaml similarity index 58% rename from experiment/config/dataset_sampler/gt_dist_weak_pair.yaml rename to experiment/config/sampling/_sampler_/gt_dist__weak_pair.yaml index 67992bf1..177e4240 100644 --- a/experiment/config/dataset_sampler/gt_dist_weak_pair.yaml +++ b/experiment/config/sampling/_sampler_/gt_dist__weak_pair.yaml @@ -1,10 +1,10 @@ -# @package _group_ -name: gt_dist_pair -sampler: +name: gt_dist__weak_pair + +sampler_cls: _target_: disent.dataset.sampling.GroundTruthDistSampler num_samples: 2 - triplet_sample_mode: ${dataset_sampling.triplet_sample_mode} # random, factors, manhattan, combined - triplet_swap_chance: ${dataset_sampling.triplet_swap_chance} + triplet_sample_mode: ${sampling.triplet_sample_mode} # random, factors, manhattan, combined + triplet_swap_chance: ${sampling.triplet_swap_chance} # ================================================== # # NOTE!!! THIS IS A DUMMY WRAPPER ,SO WE DON'T CRASH # diff --git a/experiment/config/dataset_sampler/random_pair.yaml b/experiment/config/sampling/_sampler_/random__pair.yaml similarity index 59% rename from experiment/config/dataset_sampler/random_pair.yaml rename to experiment/config/sampling/_sampler_/random__pair.yaml index 4eeb14d6..5170c299 100644 --- a/experiment/config/dataset_sampler/random_pair.yaml +++ b/experiment/config/sampling/_sampler_/random__pair.yaml @@ -1,5 +1,5 @@ -# @package _group_ -name: random_pair -sampler: +name: random__pair + +sampler_cls: _target_: disent.dataset.sampling.RandomSampler num_samples: 2 diff --git a/experiment/config/dataset_sampler/random_single.yaml b/experiment/config/sampling/_sampler_/random__single.yaml similarity index 58% rename from experiment/config/dataset_sampler/random_single.yaml rename to experiment/config/sampling/_sampler_/random__single.yaml index 9d86aef1..5a31e3d3 100644 --- a/experiment/config/dataset_sampler/random_single.yaml +++ b/experiment/config/sampling/_sampler_/random__single.yaml @@ -1,5 +1,5 @@ -# @package _group_ -name: random_single -sampler: +name: random__single + +sampler_cls: _target_: disent.dataset.sampling.RandomSampler num_samples: 1 diff --git a/experiment/config/dataset_sampler/random_triplet.yaml b/experiment/config/sampling/_sampler_/random__triplet.yaml similarity index 57% rename from experiment/config/dataset_sampler/random_triplet.yaml rename to experiment/config/sampling/_sampler_/random__triplet.yaml index 73ff64b0..7c9e9249 100644 --- a/experiment/config/dataset_sampler/random_triplet.yaml +++ b/experiment/config/sampling/_sampler_/random__triplet.yaml @@ -1,5 +1,5 @@ -# @package _group_ -name: random_triplet -sampler: +name: random__triplet + +sampler_cls: _target_: disent.dataset.sampling.RandomSampler num_samples: 3 diff --git a/experiment/config/dataset_sampler/random_weak_pair.yaml b/experiment/config/sampling/_sampler_/random__weak_pair.yaml similarity index 86% rename from experiment/config/dataset_sampler/random_weak_pair.yaml rename to experiment/config/sampling/_sampler_/random__weak_pair.yaml index f19d511c..3ff1d512 100644 --- a/experiment/config/dataset_sampler/random_weak_pair.yaml +++ 
b/experiment/config/sampling/_sampler_/random__weak_pair.yaml @@ -1,6 +1,6 @@ -# @package _group_ -name: random_pair -sampler: +name: random__weak_pair + +sampler_cls: _target_: disent.dataset.sampling.RandomSampler num_samples: 2 diff --git a/experiment/config/sampling/default.yaml b/experiment/config/sampling/default.yaml new file mode 100644 index 00000000..ae2f2920 --- /dev/null +++ b/experiment/config/sampling/default.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: choose the default from the framework and dataset +defaults: + - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} + +name: default + +# this config forces an error to be thrown if +# sampler config settings are required. diff --git a/experiment/config/sampling/default__bb.yaml b/experiment/config/sampling/default__bb.yaml new file mode 100644 index 00000000..0cdb2f65 --- /dev/null +++ b/experiment/config/sampling/default__bb.yaml @@ -0,0 +1,18 @@ +# SPECIALIZATION: choose the default from the framework and dataset +defaults: + - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} + +name: default__bb + +# varying factors (if applicable for pairs) -- sample in range: [min, max] +k: [0, -1] +k_radius: [0, -1] +# varying factors (if applicable for triplets) -- sample in range: [min, max] +n_k: [0, -1] +n_k_mode: 'bounded_below' +n_k_radius: [0, -1] +n_k_radius_mode: 'bounded_below' +# swap incorrect samples +swap_metric: NULL +# swap positive and negative if possible +swap_chance: NULL diff --git a/experiment/config/sampling/default__ran_l1.yaml b/experiment/config/sampling/default__ran_l1.yaml new file mode 100644 index 00000000..a74c04e1 --- /dev/null +++ b/experiment/config/sampling/default__ran_l1.yaml @@ -0,0 +1,18 @@ +# SPECIALIZATION: choose the default from the framework and dataset +defaults: + - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} + +name: default__ran_l1 + +# varying factors (if applicable for pairs) -- sample in range: [min, max] +k: [0, -1] +k_radius: [0, -1] +# varying factors (if applicable for triplets) -- sample in range: [min, max] +n_k: [0, -1] +n_k_mode: 'random' +n_k_radius: [0, -1] +n_k_radius_mode: 'random' +# swap incorrect samples +swap_metric: 'manhattan' +# swap positive and negative if possible +swap_chance: NULL diff --git a/experiment/config/sampling/default__ran_l2.yaml b/experiment/config/sampling/default__ran_l2.yaml new file mode 100644 index 00000000..acdb7194 --- /dev/null +++ b/experiment/config/sampling/default__ran_l2.yaml @@ -0,0 +1,18 @@ +# SPECIALIZATION: choose the default from the framework and dataset +defaults: + - _sampler_: ${dataset/_data_type_}__${framework/_input_mode_} + +name: default__ran_l2 + +# varying factors (if applicable for pairs) -- sample in range: [min, max] +k: [0, -1] +k_radius: [0, -1] +# varying factors (if applicable for triplets) -- sample in range: [min, max] +n_k: [0, -1] +n_k_mode: 'random' +n_k_radius: [0, -1] +n_k_radius_mode: 'random' +# swap incorrect samples +swap_metric: 'euclidean' +# swap positive and negative if possible +swap_chance: NULL diff --git a/experiment/config/sampling/gt_dist__combined.yaml b/experiment/config/sampling/gt_dist__combined.yaml new file mode 100644 index 00000000..d467433c --- /dev/null +++ b/experiment/config/sampling/gt_dist__combined.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__combined + +triplet_sample_mode: 
"combined" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/gt_dist__combined_scaled.yaml b/experiment/config/sampling/gt_dist__combined_scaled.yaml new file mode 100644 index 00000000..410beeb4 --- /dev/null +++ b/experiment/config/sampling/gt_dist__combined_scaled.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__combined_scaled + +triplet_sample_mode: "combined_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/gt_dist__factors.yaml b/experiment/config/sampling/gt_dist__factors.yaml new file mode 100644 index 00000000..c4a1001c --- /dev/null +++ b/experiment/config/sampling/gt_dist__factors.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__factors + +triplet_sample_mode: "factors" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/gt_dist__manhat.yaml b/experiment/config/sampling/gt_dist__manhat.yaml new file mode 100644 index 00000000..fcbad2ae --- /dev/null +++ b/experiment/config/sampling/gt_dist__manhat.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__manhat + +triplet_sample_mode: "manhattan" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/gt_dist__manhat_scaled.yaml b/experiment/config/sampling/gt_dist__manhat_scaled.yaml new file mode 100644 index 00000000..5fb96993 --- /dev/null +++ b/experiment/config/sampling/gt_dist__manhat_scaled.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__manhat_scaled + +triplet_sample_mode: "manhattan_scaled" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/gt_dist__random.yaml b/experiment/config/sampling/gt_dist__random.yaml new file mode 100644 index 00000000..2079f2c4 --- /dev/null +++ b/experiment/config/sampling/gt_dist__random.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force gt_dist to be used as the dataset sampler, this might not be supported by some datasets +defaults: + - _sampler_: gt_dist__${framework/_input_mode_} + +name: gt_dist__random + +triplet_sample_mode: "random" # random, factors, manhattan, manhattan_scaled, combined, combined_scaled +triplet_swap_chance: 0 diff --git a/experiment/config/sampling/none.yaml b/experiment/config/sampling/none.yaml new file mode 100644 index 00000000..b0f60fbd --- /dev/null +++ b/experiment/config/sampling/none.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force the user to choose a different sampling strategy +defaults: + - _sampler_: "${exit:EXITING... please specify in the defaults list a sampling method other than none}" + +name: none + +# this config forces an error to be thrown! 
This is to make +# sure that we don't encounter errors when updating old configs. diff --git a/experiment/config/sampling/random.yaml b/experiment/config/sampling/random.yaml new file mode 100644 index 00000000..cf2d7922 --- /dev/null +++ b/experiment/config/sampling/random.yaml @@ -0,0 +1,8 @@ +# SPECIALIZATION: force the random strategy to be used as the dataset sampler +defaults: + - _sampler_: random__${framework/_input_mode_} + +name: random + +# this config forces an error to be thrown if +# sampler config settings are required. diff --git a/experiment/config/schedule/adavae_down_all.yaml b/experiment/config/schedule/adavae_down_all.yaml index 62adb3f6..8d4ac7f5 100644 --- a/experiment/config/schedule/adavae_down_all.yaml +++ b/experiment/config/schedule/adavae_down_all.yaml @@ -1,27 +1,27 @@ -# @package _global_ -schedules_name: adavae_down_all -schedules: +name: averaging_decrease__all + +schedule_items: adat_triplet_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # ada triplet + r_end: 0.0 # triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # loss active + r_end: 0.0 # loss inactive adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.5 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.5 # ada weighted triplet + r_end: 1.0 # normal triplet ada_thresh_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # all averaged, should this not be 0.5 the recommended value + r_end: 0.0 # none averaged diff --git a/experiment/config/schedule/adavae_down_ratio.yaml b/experiment/config/schedule/adavae_down_ratio.yaml index 8a61ad45..e3816fe1 100644 --- a/experiment/config/schedule/adavae_down_ratio.yaml +++ b/experiment/config/schedule/adavae_down_ratio.yaml @@ -1,21 +1,21 @@ -# @package _global_ -schedules_name: adavae_down_ratio -schedules: +name: averaging_decrease__ratio + +schedule_items: adat_triplet_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # ada triplet + r_end: 0.0 # triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # loss active + r_end: 0.0 # loss inactive adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.5 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.5 # ada weighted triplet + r_end: 1.0 # normal triplet diff --git a/experiment/config/schedule/adavae_down_thresh.yaml b/experiment/config/schedule/adavae_down_thresh.yaml index 36a9a9bb..0ca660ac 100644 --- a/experiment/config/schedule/adavae_down_thresh.yaml +++ b/experiment/config/schedule/adavae_down_thresh.yaml @@ -1,9 +1,9 @@ -# @package _global_ -schedules_name: adavae_down_thresh -schedules: +name: averaging_decrease__thresh + +schedule_items: ada_thresh_ratio: 
_target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # all averaged, should this not be 0.5 the recommended value + r_end: 0.0 # none averaged diff --git a/experiment/config/schedule/adavae_up_all.yaml b/experiment/config/schedule/adavae_up_all.yaml index c3213679..58cdd98d 100644 --- a/experiment/config/schedule/adavae_up_all.yaml +++ b/experiment/config/schedule/adavae_up_all.yaml @@ -1,27 +1,27 @@ -# @package _global_ -schedules_name: adavae_up_all -schedules: +name: averaging_increase__all + +schedule_items: adat_triplet_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # triplet + r_end: 1.0 # ada triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # loss inactive + r_end: 1.0 # loss active adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.5 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # normal triplet + r_end: 0.5 # ada weighted triplet ada_thresh_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # none averaged + r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/config/schedule/adavae_up_all_full.yaml b/experiment/config/schedule/adavae_up_all_full.yaml index b9d393f1..471a8da4 100644 --- a/experiment/config/schedule/adavae_up_all_full.yaml +++ b/experiment/config/schedule/adavae_up_all_full.yaml @@ -1,27 +1,27 @@ -# @package _global_ -schedules_name: adavae_up_all -schedules: +name: averaging_increase__all_full + +schedule_items: adat_triplet_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # triplet + r_end: 1.0 # ada triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # loss inactive + r_end: 1.0 # loss active adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # normal triplet + r_end: 0.0 # ada weighted triplet ada_thresh_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # none averaged + r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/config/schedule/adavae_up_ratio.yaml b/experiment/config/schedule/adavae_up_ratio.yaml index fdead36d..79e3e6c3 100644 --- a/experiment/config/schedule/adavae_up_ratio.yaml +++ b/experiment/config/schedule/adavae_up_ratio.yaml @@ -1,21 +1,21 @@ -# @package _global_ -schedules_name: adavae_up_ratio -schedules: +name: averaging_increase__ratio + +schedule_items: adat_triplet_ratio: _target_: 
disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # triplet + r_end: 1.0 # ada triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # loss inactive + r_end: 1.0 # loss active adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.5 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # normal triplet + r_end: 0.5 # ada weighted triplet diff --git a/experiment/config/schedule/adavae_up_ratio_full.yaml b/experiment/config/schedule/adavae_up_ratio_full.yaml index 1aedef24..5b7fabe3 100644 --- a/experiment/config/schedule/adavae_up_ratio_full.yaml +++ b/experiment/config/schedule/adavae_up_ratio_full.yaml @@ -1,21 +1,21 @@ -# @package _global_ -schedules_name: adavae_up_ratio -schedules: +name: averaging_increase__ratio_full + +schedule_items: adat_triplet_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # triplet + r_end: 1.0 # ada triplet adat_triplet_soft_scale: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # loss inactive + r_end: 1.0 # loss active adat_triplet_share_scale: # reversed compared to others _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 1.0 - r_end: 0.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 # normal triplet + r_end: 0.0 # ada weighted triplet diff --git a/experiment/config/schedule/adavae_up_thresh.yaml b/experiment/config/schedule/adavae_up_thresh.yaml index d60eba88..7c012a32 100644 --- a/experiment/config/schedule/adavae_up_thresh.yaml +++ b/experiment/config/schedule/adavae_up_thresh.yaml @@ -1,9 +1,9 @@ -# @package _global_ -schedules_name: adavae_up_thresh -schedules: +name: averaging_increase__thresh + +schedule_items: ada_thresh_ratio: _target_: disent.schedule.LinearSchedule - min_step: 0 - max_step: ${trainer.steps} - r_start: 0.0 - r_end: 1.0 + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.0 # none averaged + r_end: 1.0 # all averaged, should this not be 0.5 the recommended value diff --git a/experiment/config/schedule/beta_cosine_wave.yaml b/experiment/config/schedule/beta_cosine_wave.yaml deleted file mode 100644 index 7a82976f..00000000 --- a/experiment/config/schedule/beta_cosine_wave.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# @package _global_ -schedules_name: cosine_wave -schedules: - beta: - # skip the first cycle - _target_: disent.schedule.Clip - min_step: 3600 - max_step: NULL - shift_step: TRUE - min_value: NULL - max_value: NULL - # nested schedule - schedule: - _target_: disent.schedule.CosineWave - period: 7200 - r_start: 0.001 - r_end: 1.0 diff --git a/experiment/config/schedule/beta_cyclic.yaml b/experiment/config/schedule/beta_cyclic.yaml index d309010c..636bc50c 100644 --- a/experiment/config/schedule/beta_cyclic.yaml +++ b/experiment/config/schedule/beta_cyclic.yaml @@ -1,20 +1,12 @@ -# @package _global_ -schedules_name: beta_cyclic -schedules: +name: beta_cyclic + +schedule_items: beta: - # skip the first cycle - 
_target_: disent.schedule.Clip - min_step: 3600 - max_step: NULL - shift_step: TRUE - min_value: NULL - max_value: NULL - # nested schedule - schedule: - _target_: disent.schedule.Cyclic - period: 7200 - repeats: NULL - r_start: 0.001 - r_end: 1.0 - end_value: 'end' # start/end -- only used if repeats is set - mode: 'cosine' + _target_: disent.schedule.Cyclic + period: 7200 + start_step: 3600 + repeats: NULL + r_start: 0.001 + r_end: 1.0 + end_mode: 'end' + mode: 'cosine' diff --git a/experiment/config/schedule/beta_cyclic_fast.yaml b/experiment/config/schedule/beta_cyclic_fast.yaml new file mode 100644 index 00000000..5b64fc94 --- /dev/null +++ b/experiment/config/schedule/beta_cyclic_fast.yaml @@ -0,0 +1,12 @@ +name: beta_cyclic + +schedule_items: + beta: + _target_: disent.schedule.Cyclic + period: 3600 + start_step: 3600 + repeats: NULL + r_start: 0.001 + r_end: 1.0 + end_mode: 'end' + mode: 'cosine' diff --git a/experiment/config/schedule/beta_cyclic_slow.yaml b/experiment/config/schedule/beta_cyclic_slow.yaml new file mode 100644 index 00000000..de6f0333 --- /dev/null +++ b/experiment/config/schedule/beta_cyclic_slow.yaml @@ -0,0 +1,12 @@ +name: beta_cyclic + +schedule_items: + beta: + _target_: disent.schedule.Cyclic + period: 14400 + start_step: 3600 + repeats: NULL + r_start: 0.001 + r_end: 1.0 + end_mode: 'end' + mode: 'cosine' diff --git a/experiment/config/schedule/beta_decrease.yaml b/experiment/config/schedule/beta_decrease.yaml new file mode 100644 index 00000000..b9d3cfac --- /dev/null +++ b/experiment/config/schedule/beta_decrease.yaml @@ -0,0 +1,10 @@ +name: beta_decrease + +schedule_items: + beta: + _target_: disent.schedule.Single + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 1.0 + r_end: 0.001 + mode: 'linear' diff --git a/experiment/config/schedule/beta_increase.yaml b/experiment/config/schedule/beta_increase.yaml index 04515d63..d5031ba0 100644 --- a/experiment/config/schedule/beta_increase.yaml +++ b/experiment/config/schedule/beta_increase.yaml @@ -1,20 +1,10 @@ -# @package _global_ -schedules_name: beta_increase -schedules: +name: beta_increase + +schedule_items: beta: - # skip the first cycle - _target_: disent.schedule.Clip - min_step: 3600 - max_step: NULL - shift_step: TRUE - min_value: NULL - max_value: NULL - # nested schedule - schedule: - _target_: disent.schedule.Cyclic - period: 14400 - repeats: 1 - r_start: 0.001 - r_end: 1.0 - end_value: 'start' # start/end -- only used if repeats is set - mode: 'linear' + _target_: disent.schedule.Single + start_step: 0 + end_step: ${trainer.max_steps} + r_start: 0.001 + r_end: 1.0 + mode: 'linear' diff --git a/experiment/config/schedule/none.yaml b/experiment/config/schedule/none.yaml index 904fea77..5a314b0d 100644 --- a/experiment/config/schedule/none.yaml +++ b/experiment/config/schedule/none.yaml @@ -1,3 +1,3 @@ -# @package _global_ -schedules_name: none -schedules: +name: none + +schedule_items: {} diff --git a/experiment/run.py b/experiment/run.py index 003f04df..865091d3 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -24,7 +24,6 @@ import logging import os -import sys from datetime import datetime import hydra @@ -33,6 +32,7 @@ import torch.utils.data import wandb from omegaconf import DictConfig +from omegaconf import ListConfig from omegaconf import OmegaConf from pytorch_lightning.loggers import CometLogger from pytorch_lightning.loggers import LoggerCollection @@ -46,12 +46,11 @@ from disent.util.seeds import seed from disent.util.strings.fmt import make_box_str from 
disent.util.lightning.callbacks import LoggerProgressCallback -from disent.util.lightning.callbacks import VaeDisentanglementLoggingCallback +from disent.util.lightning.callbacks import VaeMetricLoggingCallback from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback +from disent.util.lightning.callbacks import VaeGtDistsLoggingCallback from experiment.util.hydra_data import HydraDataModule from experiment.util.hydra_utils import make_non_strict -from experiment.util.hydra_utils import merge_specializations -from experiment.util.hydra_utils import instantiate_recursive from experiment.util.run_utils import log_error_and_exit from experiment.util.run_utils import safe_unset_debug_logger from experiment.util.run_utils import safe_unset_debug_trainer @@ -68,27 +67,31 @@ def hydra_check_cuda(cfg): + cuda = cfg.dsettings.trainer.cuda # set cuda - if cfg.trainer.cuda in {'try_cuda', None}: - cfg.trainer.cuda = torch.cuda.is_available() - if not cfg.trainer.cuda: + if cuda in {'try_cuda', None}: + cfg.dsettings.trainer.cuda = torch.cuda.is_available() + if not cuda: log.warning('CUDA was requested, but not found on this system... CUDA has been disabled!') else: if not torch.cuda.is_available(): - if cfg.trainer.cuda: + if cuda: log.error('trainer.cuda=True but CUDA is not available on this machine!') raise RuntimeError('CUDA not available!') else: log.warning('CUDA is not available on this machine!') else: - if not cfg.trainer.cuda: + if not cuda: log.warning('CUDA is available but is not being used!') -def hydra_check_datadir(prepare_data_per_node, cfg): - if not os.path.isabs(cfg.dataset.data_root): +def hydra_check_data_paths(cfg): + prepare_data_per_node = cfg.trainer.prepare_data_per_node + data_root = cfg.dsettings.storage.data_root + # check relative paths + if not os.path.isabs(data_root): log.warning( - f'A relative path was specified for dataset.data_root={repr(cfg.dataset.data_root)}.' + f'A relative path was specified for dsettings.storage.data_root={repr(data_root)}.' f' This is probably an error! Using relative paths can have unintended consequences' f' and performance drawbacks if the current working directory is on a shared/network drive.' f' Hydra config also uses a new working directory for each run of the program, meaning' @@ -96,139 +99,156 @@ def hydra_check_datadir(prepare_data_per_node, cfg): ) if prepare_data_per_node: log.error( - f'trainer.prepare_data_per_node={repr(prepare_data_per_node)} but dataset.data_root=' - f'{repr(cfg.dataset.data_root)} is a relative path which may be an error! Try specifying an' - f' absolute path that is guaranteed to be unique from each node, eg. dataset.data_root=/tmp/dataset' + f'trainer.prepare_data_per_node={repr(prepare_data_per_node)} but dsettings.storage.data_root=' + f'{repr(data_root)} is a relative path which may be an error! Try specifying an' + f' absolute path that is guaranteed to be unique from each node, eg. 
default_settings.storage.data_root=/tmp/dataset' ) - raise RuntimeError(f'dataset.data_root={repr(cfg.dataset.data_root)} is a relative path!') + raise RuntimeError(f'default_settings.storage.data_root={repr(data_root)} is a relative path!') def hydra_make_logger(cfg): - loggers = [] - - # initialise logging dict - cfg.setdefault('logging', {}) - - if ('wandb' in cfg.logging) and cfg.logging.wandb.setdefault('enabled', True): - loggers.append(WandbLogger( - offline=cfg.logging.wandb.setdefault('offline', False), - entity=cfg.logging.wandb.setdefault('entity', None), # cometml: workspace - project=cfg.logging.wandb.project, # cometml: project_name - name=cfg.logging.wandb.name, # cometml: experiment_name - group=cfg.logging.wandb.setdefault('group', None), # experiment group - tags=cfg.logging.wandb.setdefault('tags', None), # experiment tags - save_dir=hydra.utils.to_absolute_path(cfg.logging.logs_dir), # relative to hydra's original cwd - )) - else: - cfg.logging.setdefault('wandb', dict(enabled=False)) - - if ('cometml' in cfg.logging) and cfg.logging.cometml.setdefault('enabled', True): - loggers.append(CometLogger( - offline=cfg.logging.cometml.setdefault('offline', False), - workspace=cfg.logging.cometml.setdefault('workspace', None), # wandb: entity - project_name=cfg.logging.cometml.project, # wandb: project - experiment_name=cfg.logging.cometml.name, # wandb: name - api_key=os.environ['COMET_API_KEY'], # TODO: use dotenv - save_dir=hydra.utils.to_absolute_path(cfg.logging.logs_dir), # relative to hydra's original cwd - )) + # make wandb logger + backend = cfg.logging.wandb + if backend.enabled: + log.info('Initialising Weights & Biases Logger') + return WandbLogger( + offline=backend.offline, + entity=backend.entity, # cometml: workspace + project=backend.project, # cometml: project_name + name=backend.name, # cometml: experiment_name + group=backend.group, # experiment group + tags=backend.tags, # experiment tags + save_dir=hydra.utils.to_absolute_path(cfg.dsettings.storage.logs_dir), # relative to hydra's original cwd + ) + # don't return a logger + return None # LoggerCollection([...]) OR DummyLogger(...) + + +def _callback_make_progress(cfg, callback_cfg): + return LoggerProgressCallback( + interval=callback_cfg.interval + ) + + +def _callback_make_latent_cycle(cfg, callback_cfg): + if cfg.logging.wandb.enabled: + # checks + if not (('vis_min' in cfg.dataset and 'vis_max' in cfg.dataset) or ('vis_mean' in cfg.dataset and 'vis_std' in cfg.dataset)): + log.warning('dataset does not have visualisation ranges specified, set `vis_min` & `vis_max` OR `vis_mean` & `vis_std`') + # this currently only supports WANDB logger + return VaeLatentCycleLoggingCallback( + seed = callback_cfg.seed, + every_n_steps = callback_cfg.every_n_steps, + begin_first_step = callback_cfg.begin_first_step, + mode = callback_cfg.mode, + # recon_min = cfg.data.meta.vis_min, + # recon_max = cfg.data.meta.vis_max, + recon_mean = cfg.dataset.meta.vis_mean, + recon_std = cfg.dataset.meta.vis_std, + ) else: - cfg.logging.setdefault('cometml', dict(enabled=False)) - - # TODO: maybe return DummyLogger instead? 
- return LoggerCollection(loggers) if loggers else None # lists are turned into a LoggerCollection by pl - - -def hydra_append_progress_callback(callbacks, cfg): - if 'progress' in cfg.callbacks: - callbacks.append(LoggerProgressCallback( - interval=cfg.callbacks.progress.interval - )) - - -def hydra_append_latent_cycle_logger_callback(callbacks, cfg): - if 'latent_cycle' in cfg.callbacks: - if cfg.logging.wandb.enabled: - # this currently only supports WANDB logger - callbacks.append(VaeLatentCycleLoggingCallback( - seed=cfg.callbacks.latent_cycle.seed, - every_n_steps=cfg.callbacks.latent_cycle.every_n_steps, - begin_first_step=False, - mode=cfg.callbacks.latent_cycle.mode, - recon_min=cfg.dataset.setdefault('vis_min', 0.), - recon_max=cfg.dataset.setdefault('vis_max', 1.), - )) + log.warning('latent_cycle callback is not being used because wandb is not enabled!') + return None + + +def _callback_make_gt_dists(cfg, callback_cfg): + return VaeGtDistsLoggingCallback( + seed = callback_cfg.seed, + every_n_steps = callback_cfg.every_n_steps, + traversal_repeats = callback_cfg.traversal_repeats, + begin_first_step = callback_cfg.begin_first_step, + plt_block_size = 1.25, + plt_show = False, + plt_transpose = False, + log_wandb = True, + batch_size = cfg.settings.dataset.batch_size, + include_factor_dists = True, + ) + + +_CALLBACK_MAKERS = { + 'progress': _callback_make_progress, + 'latent_cycle': _callback_make_latent_cycle, + 'gt_dists': _callback_make_gt_dists, +} + + +def hydra_get_callbacks(cfg) -> list: + callbacks = [] + # add all callbacks + for name, item in cfg.callbacks.items(): + # custom callback handling vs instantiation + if '_target_' in item: + name = f'{name} ({item._target_})' + callback = hydra.utils.instantiate(item) else: - log.warning('latent_cycle callback is not being used because wandb is not enabled!') + callback = _CALLBACK_MAKERS[name](cfg, item) + # add to callbacks list + if callback is not None: + log.info(f'made callback: {name}') + callbacks.append(callback) + else: + log.info(f'skipped callback: {name}') + return callbacks -def hydra_append_metric_callback(callbacks, cfg): +def hydra_get_metric_callbacks(cfg) -> list: + callbacks = [] # set default values used later - default_every_n_steps = cfg.metrics.setdefault('default_every_n_steps', 3600) - default_on_final = cfg.metrics.setdefault('default_on_final', True) - default_on_train = cfg.metrics.setdefault('default_on_train', True) + default_every_n_steps = cfg.metrics.default_every_n_steps + default_on_final = cfg.metrics.default_on_final + default_on_train = cfg.metrics.default_on_train + default_begin_first_step = cfg.metrics.default_begin_first_step # get metrics - metric_list = cfg.metrics.setdefault('metric_list', []) - if metric_list == 'all': - cfg.metrics.metric_list = metric_list = [{k: {}} for k in metrics.DEFAULT_METRICS] + metric_list = cfg.metrics.metric_list + assert isinstance(metric_list, (list, ListConfig)), f'`metrics.metric_list` is not a list, got: {type(metric_list)}' # get metrics - new_metrics_list = [] - for i, metric in enumerate(metric_list): + for metric in metric_list: + assert isinstance(metric, (dict, DictConfig)), f'entry in metric list is not a dictionary, got type: {type(metric)} or value: {repr(metric)}' # fix the values if isinstance(metric, str): metric = {metric: {}} ((name, settings),) = metric.items() - if settings is None: - settings = {} - new_metrics_list.append({name: settings}) - # get metrics - every_n_steps = settings.get('every_n_steps', default_every_n_steps) 
- train_metric = [metrics.FAST_METRICS[name]] if settings.get('on_train', default_on_train) else None + # check values + assert isinstance(metric, (dict, DictConfig)), f'settings for entry in metric list is not a dictionary, got type: {type(settings)} or value: {repr(settings)}' + # make metrics + train_metric = [metrics.FAST_METRICS[name]] if settings.get('on_train', default_on_train) else None final_metric = [metrics.DEFAULT_METRICS[name]] if settings.get('on_final', default_on_final) else None # add the metric callback if final_metric or train_metric: - callbacks.append(VaeDisentanglementLoggingCallback( - every_n_steps=every_n_steps, - step_end_metrics=train_metric, - train_end_metrics=final_metric, + callbacks.append(VaeMetricLoggingCallback( + step_end_metrics = train_metric, + train_end_metrics = final_metric, + every_n_steps = settings.get('every_n_steps', default_every_n_steps), + begin_first_step = settings.get('begin_first_step', default_begin_first_step), )) - cfg.metrics.metric_list = new_metrics_list - - -def hydra_append_correlation_callback(callbacks, cfg): - if 'correlation' in cfg.callbacks: - log.warning('Correlation callback has been disabled. skipping!') - # callbacks.append(VaeLatentCorrelationLoggingCallback( - # repeats_per_factor=cfg.callbacks.correlation.repeats_per_factor, - # every_n_steps=cfg.callbacks.correlation.every_n_steps, - # begin_first_step=False, - # )) + return callbacks def hydra_register_schedules(module: DisentFramework, cfg): - if cfg.schedules is None: - cfg.schedules = {} - if cfg.schedules: + # check the type + schedule_items = cfg.schedule.schedule_items + assert isinstance(schedule_items, (dict, DictConfig)), f'`schedule.schedule_items` must be a dictionary, got type: {type(schedule_items)} with value: {repr(schedule_items)}' + # add items + if schedule_items: log.info(f'Registering Schedules:') - for target, schedule in cfg.schedules.items(): - module.register_schedule(target, instantiate_recursive(schedule), logging=True) + for target, schedule in schedule_items.items(): + module.register_schedule(target, hydra.utils.instantiate(schedule), logging=True) -def hydra_create_framework_config(cfg): +def hydra_create_and_update_framework_config(cfg) -> DisentConfigurable.cfg: # create framework config - this is also kinda hacky # - we need instantiate_recursive because of optimizer_kwargs, # otherwise the dictionary is left as an OmegaConf dict - framework_cfg: DisentConfigurable.cfg = instantiate_recursive({ - **cfg.framework.module, - **dict(_target_=cfg.framework.module._target_ + '.cfg') - }) + framework_cfg: DisentConfigurable.cfg = hydra.utils.instantiate(cfg.framework.cfg) # warn if some of the cfg variables were not overridden - missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.module.keys()))) + missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.cfg.keys()))) if missing_keys: log.error(f'Framework {repr(cfg.framework.name)} is missing config keys for:') for k in missing_keys: log.error(f'{repr(k)}') # update config params in case we missed variables in the cfg - cfg.framework.module.update(framework_cfg.to_dict()) + cfg.framework.cfg.update(framework_cfg.to_dict()) # return config return framework_cfg @@ -236,29 +256,57 @@ def hydra_create_framework_config(cfg): def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cfg): # specific handling for experiment, this is HACKY! 
# - not supported normally, we need to instantiate to get the class (is there hydra support for this?) - framework_cfg.optimizer = hydra.utils.instantiate(dict(_target_=framework_cfg.optimizer), [torch.Tensor()]).__class__ + framework_cfg.optimizer = hydra.utils.get_class(framework_cfg.optimizer) framework_cfg.optimizer_kwargs = dict(framework_cfg.optimizer_kwargs) - # instantiate - return hydra.utils.instantiate( - dict(_target_=cfg.framework.module._target_), - model=init_model_weights( - AutoEncoder( - encoder=hydra.utils.instantiate(cfg.model.encoder), - decoder=hydra.utils.instantiate(cfg.model.decoder) - ), mode=cfg.model.weight_init - ), - # apply augmentations to batch on GPU which can be faster than via the dataloader - batch_augment=datamodule.batch_augment, - cfg=framework_cfg + # get framework path + assert str.endswith(cfg.framework.cfg._target_, '.cfg'), f'`cfg.framework.cfg._target_` does not end with ".cfg", got: {repr(cfg.framework.cfg._target_)}' + framework_cls = hydra.utils.get_class(cfg.framework.cfg._target_[:-len(".cfg")]) + # create model + model = AutoEncoder( + encoder=hydra.utils.instantiate(cfg.model.encoder_cls), + decoder=hydra.utils.instantiate(cfg.model.decoder_cls), + ) + # initialise the model + model = init_model_weights(model, mode=cfg.settings.model.weight_init) + # create framework + return framework_cls( + model=model, + cfg=framework_cfg, + batch_augment=datamodule.batch_augment, # apply augmentations to batch on GPU which can be faster than via the dataloader ) # ========================================================================= # -# RUNNER # +# ACTIONS # # ========================================================================= # -def run(cfg: DictConfig, config_path: str = None): +def prepare_data(cfg: DictConfig, config_path: str = None): + # get the time the run started + time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') + log.info(f'Starting run at time: {time_string}') + raise NotImplementedError + + # # allow the cfg to be edited + # cfg = make_non_strict(cfg) + # # deterministic seed + # seed(cfg.job.setdefault('seed', None)) + # # print useful info + # log.info(f"Current working directory : {os.getcwd()}") + # log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") + # # hydra config does not support variables in defaults lists, we handle this manually + # cfg = merge_specializations(cfg, config_path=CONFIG_PATH if (config_path is None) else config_path, required=['_dataset_sampler_']) + # # check data preparation + # prepare_data_per_node = cfg.trainer.setdefault('prepare_data_per_node', True) + # hydra_check_datadir(prepare_data_per_node, cfg) + # # print the config + # log.info(f'Dataset Config Is:\n{make_box_str(OmegaConf.to_yaml({"dataset": cfg.dataset}))}') + # # prepare data + # datamodule = HydraDataModule(cfg) + # datamodule.prepare_data() + + +def train(cfg: DictConfig, config_path: str = None): # get the time the run started time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') @@ -280,7 +328,7 @@ def run(cfg: DictConfig, config_path: str = None): cfg = make_non_strict(cfg) # deterministic seed - seed(cfg.job.setdefault('seed', None)) + seed(cfg.settings.job.seed) # -~-~-~-~-~-~-~-~-~-~-~-~- # # INITIALISE & SETDEFAULT IN CONFIG @@ -293,26 +341,21 @@ def run(cfg: DictConfig, config_path: str = None): log.info(f"Current working directory : {os.getcwd()}") log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") - # hydra config does not support variables in defaults lists, 
we handle this manually - cfg = merge_specializations(cfg, config_path=CONFIG_PATH if (config_path is None) else config_path) - # check CUDA setting - cfg.trainer.setdefault('cuda', 'try_cuda') hydra_check_cuda(cfg) + # check data preparation - prepare_data_per_node = cfg.trainer.setdefault('prepare_data_per_node', True) - hydra_check_datadir(prepare_data_per_node, cfg) + hydra_check_data_paths(cfg) # TRAINER CALLBACKS - callbacks = [] - hydra_append_progress_callback(callbacks, cfg) - hydra_append_latent_cycle_logger_callback(callbacks, cfg) - hydra_append_metric_callback(callbacks, cfg) - hydra_append_correlation_callback(callbacks, cfg) + callbacks = [ + *hydra_get_callbacks(cfg), + *hydra_get_metric_callbacks(cfg), + ] # HYDRA MODULES datamodule = HydraDataModule(cfg) - framework_cfg = hydra_create_framework_config(cfg) + framework_cfg = hydra_create_and_update_framework_config(cfg) framework = hydra_create_framework(framework_cfg, datamodule, cfg) # register schedules @@ -320,30 +363,35 @@ def run(cfg: DictConfig, config_path: str = None): # Setup Trainer trainer = set_debug_trainer(pl.Trainer( - log_every_n_steps=cfg.logging.setdefault('log_every_n_steps', 50), - flush_logs_every_n_steps=cfg.logging.setdefault('flush_logs_every_n_steps', 100), logger=logger, callbacks=callbacks, - gpus=1 if cfg.trainer.cuda else 0, - max_epochs=cfg.trainer.setdefault('epochs', 100), - max_steps=cfg.trainer.setdefault('steps', None), - prepare_data_per_node=prepare_data_per_node, - progress_bar_refresh_rate=0, # ptl 0.9 - terminate_on_nan=True, # we do this here so we don't run the final metrics + gpus=1 if cfg.dsettings.trainer.cuda else 0, + # we do this here too so we don't run the final + # metrics, even through we check for it manually. + terminate_on_nan=True, # TODO: re-enable this in future... something is not compatible # with saving/checkpointing models + allow enabling from the # config. Seems like something cannot be pickled? checkpoint_callback=False, + # additional trainer kwargs + **cfg.trainer, )) # -~-~-~-~-~-~-~-~-~-~-~-~- # # BEGIN TRAINING # -~-~-~-~-~-~-~-~-~-~-~-~- # - # print the config - log.info(f'Final Config Is:\n{make_box_str(OmegaConf.to_yaml(cfg))}') - - # save hparams TODO: I think this is a pytorch lightning bug... The trainer should automatically save these if hparams is set. + # get config sections + print_cfg, boxed_pop = dict(cfg), lambda *keys: make_box_str(OmegaConf.to_yaml({k: print_cfg.pop(k) for k in keys} if keys else print_cfg)) + cfg_str_logging = boxed_pop('logging', 'callbacks', 'metrics') + cfg_str_dataset = boxed_pop('dataset', 'sampling', 'augment') + cfg_str_system = boxed_pop('framework', 'model', 'schedule') + cfg_str_settings = boxed_pop('dsettings', 'settings') + cfg_str_other = boxed_pop() + # print config sections + log.info(f'Final Config For Action: {cfg.action}\n\nLOGGING:{cfg_str_logging}\nDATASET:{cfg_str_dataset}\nSYSTEM:{cfg_str_system}\nTRAINER:{cfg_str_other}\nSETTINGS:{cfg_str_settings}') + + # save hparams TODO: is this a pytorch lightning bug? The trainer should automatically save these if hparams is set? 
framework.hparams.update(cfg) if trainer.logger: trainer.logger.log_hyperparams(framework.hparams) @@ -354,6 +402,13 @@ def run(cfg: DictConfig, config_path: str = None): trainer.fit(framework, datamodule=datamodule) +# available actions +ACTIONS = { + 'prepare_data': prepare_data, + 'train': train, +} + + # ========================================================================= # # MAIN # # ========================================================================= # @@ -367,19 +422,40 @@ def run(cfg: DictConfig, config_path: str = None): if __name__ == '__main__': + # register a custom OmegaConf resolver that allows us to put in a ${exit:msg} that exits the program + # - if we don't register this, the program will still fail because we have an unknown + # resolver. This just prettifies the output. + class ConfigurationError(Exception): + pass + + def _error_resolver(msg: str): + raise ConfigurationError(msg) + + OmegaConf.register_new_resolver('exit', _error_resolver) + @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME) def hydra_main(cfg: DictConfig): try: - run(cfg) + action_key = cfg.action + # get the action + if action_key not in ACTIONS: + raise KeyError(f'The given action: {repr(action_key)} is invalid, must be one of: {sorted(ACTIONS.keys())}') + action = ACTIONS[action_key] + # run the action + action(cfg) except Exception as e: - log_error_and_exit(err_type='experiment error', err_msg=str(e)) + log_error_and_exit(err_type='experiment error', err_msg=str(e), exc_info=True) + except: + log_error_and_exit(err_type='experiment error', err_msg='', exc_info=True) try: hydra_main() except KeyboardInterrupt as e: log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) except Exception as e: - log_error_and_exit(err_type='hydra error', err_msg=str(e)) + log_error_and_exit(err_type='hydra error', err_msg=str(e), exc_info=True) + except: + log_error_and_exit(err_type='hydra error', err_msg='', exc_info=True) # ========================================================================= # diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index 0b387b6b..daa4c191 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -23,14 +23,15 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +import warnings + import hydra import torch.utils.data import pytorch_lightning as pl from omegaconf import DictConfig from disent.dataset import DisentDataset -from disent.nn.transform import DisentDatasetTransform -from experiment.util.hydra_utils import instantiate_recursive +from disent.dataset.transform import DisentDatasetTransform log = logging.getLogger(__name__) @@ -88,17 +89,18 @@ def __init__(self, hparams: DictConfig): else: self.hparams.update(hparams) # transform: prepares data from datasets - self.data_transform = instantiate_recursive(self.hparams.dataset.transform) + self.data_transform = hydra.utils.instantiate(self.hparams.dataset.transform) assert (self.data_transform is None) or callable(self.data_transform) # input_transform_aug: augment data for inputs, then apply input_transform - self.input_transform = instantiate_recursive(self.hparams.augment.transform) + self.input_transform = hydra.utils.instantiate(self.hparams.augment.augment_cls) assert (self.input_transform is None) or callable(self.input_transform) # batch_augment: augments transformed data for inputs, should be applied across a batch # which version of the dataset we need to use if GPU augmentation is 
enabled or not. # - corresponds to below in train_dataloader() - if self.hparams.dataset.gpu_augment: + if self.hparams.dsettings.dataset.gpu_augment: # TODO: this is outdated! self.batch_augment = DisentDatasetTransform(transform=self.input_transform) + warnings.warn('`gpu_augment=True` is outdated and may no longer be equivalent to `gpu_augment=False`') else: self.batch_augment = None # datasets initialised in setup() @@ -117,16 +119,16 @@ def prepare_data(self) -> None: # things could go wrong. We try be efficient about it by removing the # in_memory argument if it exists. log.info(f'Data - Preparation & Downloading') - instantiate_recursive(data) + hydra.utils.instantiate(data) def setup(self, stage=None) -> None: # ground truth data log.info(f'Data - Instance') - data = instantiate_recursive(self.hparams.dataset.data) + data = hydra.utils.instantiate(self.hparams.dataset.data) # Wrap the data for the framework some datasets need triplets, pairs, etc. # Augmentation is done inside the frameworks so that it can be done on the GPU, otherwise things are very slow. - self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.dataset_sampler.sampler), transform=self.data_transform, augment=None) - self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.dataset_sampler.sampler), transform=self.data_transform, augment=self.input_transform) + self.dataset_train_noaug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=None) + self.dataset_train_aug = DisentDataset(data, hydra.utils.instantiate(self.hparams.sampling._sampler_.sampler_cls), transform=self.data_transform, augment=self.input_transform) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Training Dataset: @@ -147,17 +149,21 @@ def train_dataloader(self): """ # Select which version of the dataset we need to use if GPU augmentation is enabled or not. # - corresponds to above in __init__() - if self.hparams.dataset.gpu_augment: + if self.hparams.dsettings.dataset.gpu_augment: dataset = self.dataset_train_noaug else: dataset = self.dataset_train_aug - # create dataloader - return torch.utils.data.DataLoader( - dataset=dataset, - batch_size=self.hparams.dataset.batch_size, - num_workers=self.hparams.dataset.num_workers, - shuffle=True, + # get default kwargs + default_kwargs = { + 'shuffle': True, # This should usually be TRUE if cuda is enabled. # About 20% faster with the xysquares dataset, RTX 2060 Rev. 
diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py
index 79258cf5..87e633e8 100644
--- a/experiment/util/hydra_utils.py
+++ b/experiment/util/hydra_utils.py
@@ -24,6 +24,8 @@
 
 import logging
 from copy import deepcopy
+from typing import Optional
+from typing import Sequence
 
 import hydra
 from omegaconf import DictConfig
@@ -37,87 +39,18 @@
 
 
 # ========================================================================= #
-# Recursive Hydra Instantiation                                             #
-# TODO: use https://github.com/facebookresearch/hydra/pull/989              #
-# I think this is quicker? Just doesn't perform checks...                   #
+# Helper                                                                    #
 # ========================================================================= #
 
 
-@deprecated('replace with hydra 1.1')
-def call_recursive(config):
-    # recurse
-    def _call_recursive(config):
-        if isinstance(config, (dict, DictConfig)):
-            c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'}
-            if '_target_' in config:
-                config = hydra.utils.instantiate({'_target_': config['_target_']}, **c)
-        elif isinstance(config, (tuple, list, ListConfig)):
-            config = [_call_recursive(v) for v in config]
-        return config
-    return _call_recursive(config)
-
-
-# alias
-@deprecated('replace with hydra 1.1')
-def instantiate_recursive(config):
-    return call_recursive(config)
-
-
-@deprecated('replace with hydra 1.1')
-def instantiate_object_if_needed(config_or_object):
-    if isinstance(config_or_object, dict):
-        return instantiate_recursive(config_or_object)
-    else:
-        return config_or_object
-
-
-# ========================================================================= #
-# Better Specializations                                                    #
-# TODO: this might be replaced by recursive instantiation                   #
-# https://github.com/facebookresearch/hydra/pull/1044                       #
-# ========================================================================= #
-
-
-@deprecated('replace with hydra 1.1')
 def make_non_strict(cfg: DictConfig):
+    """
+    Convert the config into a mutable version.
+    """
     cfg = deepcopy(cfg)
     return OmegaConf.create({**cfg})
 
 
-@deprecated('replace with hydra 1.1')
-def merge_specializations(cfg: DictConfig, config_path: str, strict=True):
-    import os
-
-    # TODO: this should eventually be replaced with hydra recursive defaults
-    # TODO: this makes config non-strict, allows setdefault to work even if key does not exist in config
-
-    assert os.path.isabs(config_path), f'config_path cannot be relative for merge_specializations: {repr(config_path)}, current working directory: {repr(os.getcwd())}'
-
-    # skip if we do not have any specializations
-    if 'specializations' not in cfg:
-        log.warning('`specializations` key not found in `cfg`, skipping merging specializations')
-        return
-
-    # we allow overwrites & missing values to be inserted
-    if not strict:
-        cfg = make_non_strict(cfg)
-
-    # set and update specializations
-    for group, specialization in cfg.specializations.items():
-        assert group not in cfg, f'group={repr(group)} already exists on cfg, specialization merging is not supported!'
-        log.info(f'merging specialization: {repr(specialization)}')
-        # load specialization config
-        specialization_cfg = OmegaConf.load(os.path.join(config_path, group, f'{specialization}.yaml'))
-        # create new config
-        cfg = OmegaConf.merge(cfg, {group: specialization_cfg})
-
-    # remove specializations key
-    del cfg['specializations']
-
-    # done
-    return cfg
-
-
 # ========================================================================= #
 # END                                                                       #
 # ========================================================================= #
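Note: a rough illustration of the `make_non_strict` helper that remains after this cleanup (names and values below are made up). Hydra composes configs in struct mode, where assigning a key that was never defined raises an error; the copy returned by `make_non_strict` is a plain, mutable config.

```python
from omegaconf import OmegaConf
from experiment.util.hydra_utils import make_non_strict

cfg = OmegaConf.create({'framework': {'name': 'adavae_os'}})
OmegaConf.set_struct(cfg, True)   # hydra-composed configs behave like this
# cfg.seed = 42                   # would raise: 'seed' is not a known key

cfg = make_non_strict(cfg)        # deepcopy + rebuild without struct mode
cfg.seed = 42                     # now allowed
```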
diff --git a/requirements-experiment.txt b/requirements-experiment.txt
index 85ef3e77..01c2db1f 100644
--- a/requirements-experiment.txt
+++ b/requirements-experiment.txt
@@ -24,9 +24,10 @@
 wandb>=0.10.32
 
 # UTILITY
 # =======
-hydra-core==1.0.7                 # includes omegaconf
-hydra-colorlog==1.0.1
-hydra-submitit-launcher==1.1.1
+omegaconf>=2.1.0                  # only 2.1.0 supports nested variable interpolation, e.g. ${group.${group.key}}
+hydra-core==1.1.1                 # needs omegaconf
+hydra-colorlog==1.1.0
+hydra-submitit-launcher==1.1.6
 
 # MISSING DEPS - these are imported or referenced (_target_) in /experiments, but not included here OR in requirements.txt
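Note: for context on the `omegaconf>=2.1.0` pin above, this is the kind of nested interpolation (`${group.${group.key}}`) that only resolves from 2.1.0 onwards; the keys and values below are made up.

```yaml
group:
  key: b                        # selects which sibling entry to read
  a: 1
  b: 2

value: ${group.${group.key}}    # inner ${group.key} -> 'b', so this resolves to 2
```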
diff --git a/setup.py b/setup.py
index db5f2926..87deb7c0 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,7 @@
     author="Nathan Juraj Michlo",
     author_email="NathanJMichlo@gmail.com",
 
-    version="0.2.1",
+    version="0.3.0",
 
     python_requires=">=3.8",  # we make use of standard library features only in 3.8
     packages=setuptools.find_packages(),
diff --git a/tests/test_data.py b/tests/test_data.py
index 1d315f77..17e836e3 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -48,7 +48,7 @@
 
 
 # factors=(3, 3, 2, 3), len=54
-TestXYObjectData = wrapped_partial(XYObjectData, grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb')
+TestXYObjectData = wrapped_partial(XYObjectData, grid_size=4, grid_spacing=1, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb_1')
 
 _TEST_LEN = 54
diff --git a/tests/test_data_similarity.py b/tests/test_data_similarity.py
new file mode 100644
index 00000000..33169471
--- /dev/null
+++ b/tests/test_data_similarity.py
@@ -0,0 +1,64 @@
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+# MIT License
+#
+# Copyright (c) 2021 Nathan Juraj Michlo
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
+
+import numpy as np
+
+from disent.dataset.data import XYObjectData
+from disent.dataset.data import XYObjectShadedData
+
+
+# ========================================================================= #
+# TESTS                                                                     #
+# ========================================================================= #
+
+
+def test_xyobject_similarity():
+    for palette in XYObjectData.COLOR_PALETTES_3.keys():
+        # create
+        data0 = XYObjectData(palette=palette)
+        data1 = XYObjectShadedData(palette=palette)
+        assert len(data0) == len(data1)
+        assert data0.factor_sizes == (*data1.factor_sizes[:-2], np.prod(data1.factor_sizes[-2:]))
+        # check random
+        for i in np.random.randint(len(data0), size=100):
+            assert np.allclose(data0[i], data1[i])
+
+
+def test_xyobject_grey_similarity():
+    for palette in XYObjectData.COLOR_PALETTES_1.keys():
+        # create
+        data0 = XYObjectData(palette=palette, rgb=False)
+        data1 = XYObjectShadedData(palette=palette, rgb=False)
+        assert len(data0) == len(data1)
+        assert data0.factor_sizes == (*data1.factor_sizes[:-2], np.prod(data1.factor_sizes[-2:]))
+        # check random
+        for i in np.random.randint(len(data0), size=100):
+            assert np.allclose(data0[i], data1[i])
+
+
+# ========================================================================= #
+# END                                                                       #
+# ========================================================================= #
diff --git a/tests/test_experiment.py b/tests/test_experiment.py
index 31d98169..00b72f8e 100644
--- a/tests/test_experiment.py
+++ b/tests/test_experiment.py
@@ -41,7 +41,7 @@ def test_experiment_run():
     os.environ['HYDRA_FULL_ERROR'] = '1'
     with temp_sys_args([experiment_run.__file__]):
         # why does this not work when config is absolute?
-        hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.run)
+        hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.train)
         hydra_main()
diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py
index f7f7c9cf..91c1750c 100644
--- a/tests/test_frameworks.py
+++ b/tests/test_frameworks.py
@@ -37,9 +37,9 @@
 from disent.frameworks.ae import *
 from disent.frameworks.vae import *
 from disent.model import AutoEncoder
-from disent.model.ae import DecoderTest
-from disent.model.ae import EncoderTest
-from disent.nn.transform import ToStandardisedTensor
+from disent.model.ae import DecoderLinear
+from disent.model.ae import EncoderLinear
+from disent.dataset.transform import ToImgTensorF32
 
 
 # ========================================================================= #
@@ -66,6 +66,7 @@
     # VAE - weakly supervised
     (AdaVae, dict(), XYObjectData),
     (AdaVae, dict(ada_average_mode='ml-vae'), XYObjectData),
+    (AdaGVaeMinimal, dict(), XYObjectData),
     # VAE - supervised
     (TripletVae, dict(), XYObjectData),
     (TripletVae, dict(disable_decoder=True, disable_reg_loss=True, disable_posterior_scale=0.5), XYObjectData),
@@ -78,13 +79,13 @@ def test_frameworks(Framework, cfg_kwargs, Data):
     }[Framework.REQUIRED_OBS]
 
     data = XYObjectData() if (Data is None) else Data()
-    dataset = DisentDataset(data, DataSampler(), transform=ToStandardisedTensor())
+    dataset = DisentDataset(data, DataSampler(), transform=ToImgTensorF32())
     dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
 
     framework = Framework(
         model=AutoEncoder(
-            encoder=EncoderTest(x_shape=data.x_shape, z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1),
-            decoder=DecoderTest(x_shape=data.x_shape, z_size=6),
+            encoder=EncoderLinear(x_shape=data.x_shape, z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1),
+            decoder=DecoderLinear(x_shape=data.x_shape, z_size=6),
         ),
         cfg=Framework.cfg(**cfg_kwargs)
     )
diff --git a/tests/test_math.py b/tests/test_math.py
index 1921c18a..69c2e4cd 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -41,7 +41,7 @@
 from disent.nn.functional import torch_idct
 from disent.nn.functional import torch_idct2
 from disent.nn.functional import torch_mean_generalized
-from disent.nn.transform import ToStandardisedTensor
+from disent.dataset.transform import ToImgTensorF32
 from disent.util import to_numpy
 
 
@@ -132,9 +132,9 @@ def test_dct():
 
 def test_fft_conv2d():
     data = XYObjectData()
-    dataset = DisentDataset(data, RandomSampler(), transform=ToStandardisedTensor(), augment=None)
+    dataset = DisentDataset(data, RandomSampler(), transform=ToImgTensorF32(), augment=None)
     # sample data
-    factors = dataset.ground_truth_data.sample_random_factor_traversal(f_idx=2)
+    factors = dataset.gt_data.sample_random_factor_traversal(f_idx=2)
     batch = dataset.dataset_batch_from_factors(factors=factors, mode="input")
     # test torch_conv2d_channel_wise variants
     for i in range(1, 5):
diff --git a/tests/test_math_generic.py b/tests/test_math_generic.py
index 9f692b05..77fa42ca 100644
--- a/tests/test_math_generic.py
+++ b/tests/test_math_generic.py
@@ -26,11 +26,11 @@
 import pytest
 import torch
 
-from disent.nn.functional._generic_tensors import generic_as_int32
-from disent.nn.functional._generic_tensors import generic_max
-from disent.nn.functional._generic_tensors import generic_min
-from disent.nn.functional._generic_tensors import generic_ndim
-from disent.nn.functional._generic_tensors import generic_shape
+from disent.nn.functional._util_generic import generic_as_int32
+from disent.nn.functional._util_generic import generic_max
+from disent.nn.functional._util_generic import generic_min
+from disent.nn.functional._util_generic import generic_ndim
+from disent.nn.functional._util_generic import generic_shape
 
 
 # ========================================================================= #
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 5945fd40..92b0a9bc 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -28,7 +28,7 @@
 from disent.dataset.data import XYObjectData
 from disent.dataset import DisentDataset
 from disent.metrics import *
-from disent.nn.transform import ToStandardisedTensor
+from disent.dataset.transform import ToImgTensorF32
 from disent.util.function import wrapped_partial
 
 
@@ -48,7 +48,7 @@ def test_metrics(metric_fn):
     z_size = 8
     # ground truth data
     # TODO: DisentDataset should not be needed to compute metrics!
-    dataset = DisentDataset(XYObjectData(), transform=ToStandardisedTensor())
+    dataset = DisentDataset(XYObjectData(), transform=ToImgTensorF32())
     # randomly sampled representation
     get_repr = lambda x: torch.randn(len(x), z_size)
     # evaluate
diff --git a/tests/test_models.py b/tests/test_models.py
index b1f2e1ce..7155d732 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -34,7 +34,7 @@
     [EncoderConv64, DecoderConv64],
     [EncoderConv64Norm, DecoderConv64Norm],
     [EncoderFC, DecoderFC],
-    [EncoderTest, DecoderTest],
+    [EncoderLinear, DecoderLinear],
 ])
 def test_ae_models(encoder_cls: DisentEncoder, decoder_cls: DisentDecoder):
     x_shape, z_size = (3, 64, 64), 8
diff --git a/tests/test_registry.py b/tests/test_registry.py
index edb1c686..c173d969 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -23,7 +23,7 @@
 # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
 
 
-from disent.registry import REGISTRY
+from disent.registry import REGISTRIES
 
 
 # ========================================================================= #
@@ -32,15 +32,15 @@
 
 
 COUNTS = {
-    'dataset': 6,
-    'sampler': 8,
-    'framework': 10,
-    'recon_loss': 6,
-    'latent_dist': 2,
-    'optimizer': 38,
-    'metric': 5,
-    'schedule': 5,
-    'model': 8,
+    'DATASETS': 6,
+    'SAMPLERS': 8,
+    'FRAMEWORKS': 10,
+    'RECON_LOSSES': 6,
+    'LATENT_DISTS': 2,
+    'OPTIMIZERS': 30,
+    'METRICS': 5,
+    'SCHEDULES': 5,
+    'MODELS': 8,
 }
 
 
@@ -49,14 +49,14 @@ def test_registry_loading():
     # load everything and check the counts
     total = 0
-    for registry in REGISTRY:
+    for registry in REGISTRIES:
         count = 0
-        for name in REGISTRY[registry]:
-            loaded = REGISTRY[registry][name]
+        for name in REGISTRIES[registry]:
+            loaded = REGISTRIES[registry][name]
             count += 1
             total += 1
-        assert COUNTS[registry] == count
-    assert total == sum(COUNTS.values())
+        assert COUNTS[registry] == count, f'invalid count for: {registry}'
+    assert total == sum(COUNTS.values()), f'invalid total'
 
 
 # ========================================================================= #
diff --git a/tests/test_transform.py b/tests/test_transform.py
index b6a4ced2..a21ca1e2 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -24,8 +24,8 @@
 import pytest
 import torch
 
-from disent.nn.transform import FftGaussianBlur
-from disent.nn.transform._augment import _expand_to_min_max_tuples
+from disent.dataset.transform import FftGaussianBlur
+from disent.dataset.transform._augment import _expand_to_min_max_tuples
 from disent.nn.functional import torch_gaussian_kernel
 from disent.nn.functional import torch_gaussian_kernel_2d
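Note: a small usage sketch of the renamed `REGISTRIES` mapping, mirroring the access pattern exercised by the test above; nothing here goes beyond what the test already shows.

```python
from disent.registry import REGISTRIES

# registries are looked up by their upper-case names, entries by string keys
frameworks = REGISTRIES['FRAMEWORKS']
for name in frameworks:
    cls = frameworks[name]   # entries are loaded on access, as in the test
    print(name, cls)
```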