diff --git a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy1x1.pt b/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy1x1.pt deleted file mode 100644 index a5c2fdac..00000000 Binary files a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy1x1.pt and /dev/null differ diff --git a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy8x8.pt b/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy8x8.pt deleted file mode 100644 index faf1f852..00000000 Binary files a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0001_xy8x8.pt and /dev/null differ diff --git a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt b/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt deleted file mode 100644 index 15aee5c6..00000000 Binary files a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt and /dev/null differ diff --git a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt b/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt deleted file mode 100644 index 2d529163..00000000 Binary files a/data/adversarial_kernel/r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt and /dev/null differ diff --git a/disent/dataset/data/__init__.py b/disent/dataset/data/__init__.py index 178aa183..1a870327 100644 --- a/disent/dataset/data/__init__.py +++ b/disent/dataset/data/__init__.py @@ -44,15 +44,10 @@ # groundtruth -- impl from disent.dataset.data._groundtruth__cars3d import Cars3dData from disent.dataset.data._groundtruth__dsprites import DSpritesData -from disent.dataset.data._groundtruth__dsprites_imagenet import DSpritesImagenetData # pragma: delete-on-release from disent.dataset.data._groundtruth__mpi3d import Mpi3dData from disent.dataset.data._groundtruth__norb import SmallNorbData from disent.dataset.data._groundtruth__shapes3d import Shapes3dData # groundtruth -- impl synthetic -from disent.dataset.data._groundtruth__xyblocks import XYBlocksData # pragma: delete-on-release from disent.dataset.data._groundtruth__xyobject import XYObjectData from disent.dataset.data._groundtruth__xyobject import XYObjectShadedData -from disent.dataset.data._groundtruth__xysquares import XYSquaresData # pragma: delete-on-release -from disent.dataset.data._groundtruth__xysquares import XYSquaresMinimalData # pragma: delete-on-release -from disent.dataset.data._groundtruth__xcolumns import XColumnsData # pragma: delete-on-release diff --git a/disent/dataset/data/_groundtruth__xcolumns.py b/disent/dataset/data/_groundtruth__xcolumns.py deleted file mode 100644 index b50503ce..00000000 --- a/disent/dataset/data/_groundtruth__xcolumns.py +++ /dev/null @@ -1,66 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from typing import Tuple - -import numpy as np - -from disent.dataset.data._groundtruth__xysquares import XYSquaresData - - -# ========================================================================= # -# xy multi grid data # -# ========================================================================= # - - -class XColumnsData(XYSquaresData): - - name = 'x_columns' - - @property - def factor_names(self) -> Tuple[str, ...]: - return ('x_R', 'x_G', 'x_B')[:self._num_squares] - - @property - def factor_sizes(self) -> Tuple[int, ...]: - return (self._placements,) * self._num_squares - - def _get_observation(self, idx): - # get factors - factors = self.idx_to_pos(idx) - offset, space, size = self._offset, self._spacing, self._square_size - # GENERATE - obs = np.zeros(self.img_shape, dtype=self._dtype) - for i, fx in enumerate(factors): - x = offset + space * fx - if self._rgb: - obs[:, x:x+size, i] = self._fill_value - else: - obs[:, x:x+size, :] = self._fill_value - return obs - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/dataset/data/_groundtruth__xyblocks.py b/disent/dataset/data/_groundtruth__xyblocks.py deleted file mode 100644 index efeae411..00000000 --- a/disent/dataset/data/_groundtruth__xyblocks.py +++ /dev/null @@ -1,160 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from typing import Tuple - -import numpy as np - -from disent.dataset.data._groundtruth import GroundTruthData - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# xy squares data # -# ========================================================================= # - - -class XYBlocksData(GroundTruthData): - - """ - Dataset that generates all possible permutations of xor'd squares of - different scales moving across the grid. - - This dataset is designed not to overlap in the reconstruction loss space, but xor'ing may be too - complex to learn efficiently, and some sizes of factors may be too small (eg. biggest - square moving only has two positions) - """ - - COLOR_PALETTES_1 = { - 'white': [ - [255], - ], - 'greys_halves': [ - [128], - [255], - ], - 'greys_quarters': [ - [64], - [128], - [192], - [255], - ], - # alias for white, so that we can just set `rgb=False` - 'rgb': [ - [255], - ] - } - - COLOR_PALETTES_3 = { - 'white': [ - [255, 255, 255], - ], - # THIS IS IDEAL. - 'rgb': [ - [255, 000, 000], - [000, 255, 000], - [000, 000, 255], - ], - 'colors': [ - [255, 000, 000], [000, 255, 000], [000, 000, 255], - [255, 255, 000], [000, 255, 255], [255, 000, 255], - [255, 255, 255], - ], - } - - @property - def factor_names(self) -> Tuple[str, ...]: - return self._factor_names - - @property - def factor_sizes(self) -> Tuple[int, ...]: - return self._factor_sizes - - @property - def img_shape(self) -> Tuple[int, ...]: - return self._img_shape - - def __init__( - self, - grid_size: int = 64, - grid_levels: Tuple[int, ...] = (1, 2, 3), - rgb: bool = True, - palette: str = 'rgb', - invert_bg: bool = False, - transform=None, - ): - # colors - self._rgb = rgb - if palette != 'rgb': - log.warning('rgb palette is not being used, might overlap for the reconstruction loss.') - if rgb: - assert palette in XYBlocksData.COLOR_PALETTES_3, f'{palette=} must be one of {list(XYBlocksData.COLOR_PALETTES_3.keys())}' - self._colors = np.array(XYBlocksData.COLOR_PALETTES_3[palette]) - else: - assert palette in XYBlocksData.COLOR_PALETTES_1, f'{palette=} must be one of {list(XYBlocksData.COLOR_PALETTES_1.keys())}' - self._colors = np.array(XYBlocksData.COLOR_PALETTES_1[palette]) - - # bg colors - self._bg_color = 255 if invert_bg else 0 # we dont need rgb for this - assert not np.any([np.all(self._bg_color == color) for color in self._colors]), f'Color conflict with background: {self._bg_color} ({invert_bg=}) in {self._colors}' - - # grid - grid_levels = np.arange(1, grid_levels+1) if isinstance(grid_levels, int) else np.array(grid_levels) - assert np.all(grid_size % (2 ** grid_levels) == 0), f'{grid_size=} is not divisible by pow(2, {grid_levels=})' - assert np.all(grid_levels[:-1] <= grid_levels[1:]) - self._grid_size = grid_size - self._grid_levels = grid_levels - self._grid_dims = len(grid_levels) - - # axis sizes - self._axis_divisions = 2 ** self._grid_levels - assert len(self._axis_divisions) == self._grid_dims and np.all(grid_size % self._axis_divisions) == 0, 'This should never happen' - self._axis_division_sizes = grid_size // self._axis_divisions - - # info - self._factor_names = tuple([f'{prefix}-{d}' for prefix in ['color', 'x', 'y'] for d in self._axis_divisions]) - self._factor_sizes = tuple([len(self._colors)] * self._grid_dims + list(self._axis_divisions) * 2) - self._img_shape = (grid_size, grid_size, 3 if self._rgb else 1) - - # 
initialise - super().__init__(transform=transform) - - def _get_observation(self, idx): - positions = self.idx_to_pos(idx) - cs, xs, ys = positions[:self._grid_dims*1], positions[self._grid_dims*1:self._grid_dims*2], positions[self._grid_dims*2:] - assert len(xs) == len(ys) == len(cs) - # GENERATE - obs = np.full(self.img_shape, self._bg_color, dtype=np.uint8) - for i, (x, y, s, c) in enumerate(zip(xs, ys, self._axis_division_sizes, cs)): - obs[y*s:(y+1)*s, x*s:(x+1)*s, :] = self._colors[c] if np.any(obs[y*s, x*s, :] != self._colors[c]) else self._bg_color - # RETURN - return obs - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/dataset/data/_groundtruth__xysquares.py b/disent/dataset/data/_groundtruth__xysquares.py deleted file mode 100644 index 01c7e4d6..00000000 --- a/disent/dataset/data/_groundtruth__xysquares.py +++ /dev/null @@ -1,202 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from typing import Optional -from typing import Tuple -from typing import Union - -import numpy as np - -from disent.dataset.data._groundtruth import GroundTruthData -from disent.util.iters import iter_chunks - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# xy multi grid data # -# ========================================================================= # - - -class XYSquaresMinimalData(GroundTruthData): - """ - Dataset that generates all possible permutations of 3 (R, G, B) coloured - squares placed on a square grid. This dataset is designed to not overlap - in the reconstruction loss space. - - If you use this in your work, please cite: https://github.com/nmichlo/disent - - NOTE: Unlike XYSquaresData, XYSquaresMinimalData is the bare-minimum class - to generate the same results as the default values for XYSquaresData, - this class is a fair bit faster (~0.8x)! - - All 3 squares are returned, in RGB, each square is size 8, with - non-overlapping grid spacing set to 8 pixels, in total leaving - 8*8*8*8*8*8 factors. Images are uint8 with fill values 0 (bg) - and 255 (fg). 
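-
-    A minimal usage sketch (an illustration only -- it assumes the standard
-    GroundTruthData indexing API used elsewhere in this repo):
-
-        data = XYSquaresMinimalData()
-        obs = data[0]             # np.uint8 image of shape (64, 64, 3)
-        pos = data.idx_to_pos(0)  # factor values (x_R, y_R, x_G, y_G, x_B, y_B)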
- """ - - name = 'xy_squares_minimal' - - @property - def factor_names(self) -> Tuple[str, ...]: - return 'x_R', 'y_R', 'x_G', 'y_G', 'x_B', 'y_B' - - @property - def factor_sizes(self) -> Tuple[int, ...]: - return 8, 8, 8, 8, 8, 8 # R, G, B squares - - @property - def img_shape(self) -> Tuple[int, ...]: - return 64, 64, 3 - - def _get_observation(self, idx): - # get factors - factors = np.reshape(np.unravel_index(idx, self.factor_sizes), (-1, 2)) - # GENERATE - obs = np.zeros(self.img_shape, dtype=np.uint8) - for i, (fx, fy) in enumerate(factors): - x, y = 8 * fx, 8 * fy - obs[y:y+8, x:x+8, i] = 255 - return obs - - -# ========================================================================= # -# xy multi grid data # -# ========================================================================= # - - -class XYSquaresData(GroundTruthData): - - """ - Dataset that generates all possible permutations of 3 (R, G, B) coloured - squares placed on a square grid. This dataset is designed to not overlap - in the reconstruction loss space. (if the spacing is set correctly.) - - If you use this in your work, please cite: https://github.com/nmichlo/disent - - NOTE: Unlike XYSquaresMinimalData, XYSquaresData allows adjusting various aspects - of the data that is generated, but the generation process is slower (~1.25x). - """ - - name = 'xy_squares' - - @property - def factor_names(self) -> Tuple[str, ...]: - return ('x_R', 'y_R', 'x_G', 'y_G', 'x_B', 'y_B')[:self._num_squares*2] - - @property - def factor_sizes(self) -> Tuple[int, ...]: - return (self._placements, self._placements) * self._num_squares # R, G, B squares - - @property - def img_shape(self) -> Tuple[int, ...]: - return self._width, self._width, (3 if self._rgb else 1) - - def __init__( - self, - square_size: int = 8, - image_size: int = 64, - grid_size: Optional[int] = None, - grid_spacing: Optional[int] = None, - num_squares: int = 3, - rgb: bool = True, - fill_value: Optional[Union[float, int]] = None, - dtype: Union[np.dtype, str] = np.uint8, - no_warnings: bool = False, - transform=None, - ): - """ - :param square_size: the size of the individual squares in pixels - :param image_size: the image size in pixels - :param grid_spacing: the step size between square positions on the grid. By - default this is set to square_size which results in non-overlapping - data if `grid_spacing >= square_size` Reducing this value such that - `grid_spacing < square_size` results in overlapping data. - :param num_squares: The number of squares drawn. `1 <= num_squares <= 3` - :param rgb: Image has 3 channels if True, otherwise it is greyscale with 1 channel. - :param no_warnings: If warnings should be disabled if overlapping. - :param fill_value: The foreground value to use for filling squares, the default background value is 0. - :param grid_size: The number of grid positions available for the square to be placed in. 
The square is centered if this is less than - :param dtype: - """ - if grid_spacing is None: - grid_spacing = square_size - if (grid_spacing < square_size) and not no_warnings: - log.warning(f'overlap between squares for reconstruction loss, {grid_spacing} < {square_size}') - # color - self._rgb = rgb - self._dtype = np.dtype(dtype) - # check fill values - if self._dtype.kind == 'u': - self._fill_value = 255 if (fill_value is None) else fill_value - assert isinstance(self._fill_value, int) - assert 0 < self._fill_value <= 255, f'0 < {self._fill_value} <= 255' - elif self._dtype.kind == 'f': - self._fill_value = 1.0 if (fill_value is None) else fill_value - assert isinstance(self._fill_value, (int, float)) - assert 0.0 < self._fill_value <= 1.0, f'0.0 < {self._fill_value} <= 1.0' - else: - raise TypeError(f'invalid dtype: {self._dtype}, must be float or unsigned integer') - # image sizes - self._width = image_size - # number of squares - self._num_squares = num_squares - assert 1 <= num_squares <= 3, 'Only 1, 2 or 3 squares are supported!' - # square scales - self._square_size = square_size - # x, y - self._spacing = grid_spacing - self._placements = (self._width - self._square_size) // grid_spacing + 1 - # maximum placements - if grid_size is not None: - assert isinstance(grid_size, int) - assert grid_size > 0 - if (grid_size > self._placements) and not no_warnings: - log.warning(f'number of possible placements: {self._placements} is less than the given grid size: {grid_size}, reduced grid size from: {grid_size} -> {self._placements}') - self._placements = min(self._placements, grid_size) - # center elements - self._offset = (self._width - (self._square_size + (self._placements-1)*self._spacing)) // 2 - # initialise parents -- they depend on self.factors - super().__init__(transform=transform) - - def _get_observation(self, idx): - # get factors - factors = self.idx_to_pos(idx) - offset, space, size = self._offset, self._spacing, self._square_size - # GENERATE - obs = np.zeros(self.img_shape, dtype=self._dtype) - for i, (fx, fy) in enumerate(iter_chunks(factors, 2)): - x, y = offset + space * fx, offset + space * fy - if self._rgb: - obs[y:y+size, x:x+size, i] = self._fill_value - else: - obs[y:y+size, x:x+size, :] = self._fill_value - return obs - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/dataset/transform/_augment.py b/disent/dataset/transform/_augment.py index 22dfcb72..69082a1d 100644 --- a/disent/dataset/transform/_augment.py +++ b/disent/dataset/transform/_augment.py @@ -237,8 +237,6 @@ def _check_kernel(kernel: torch.Tensor) -> torch.Tensor: # (REGEX, EXAMPLE, FACTORY_FUNC) # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first. 
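# - e.g. 'box_r31' is matched by r'^(box)_r(\d+)$'; judging by the entries below,
#   the factory then receives one argument per capture group: fn('box', '31')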
- (re.compile(r'^(xy8)_r(47)$'), 'xy8_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, '../../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy8x8.pt')))), # pragma: delete-on-release - (re.compile(r'^(xy1)_r(47)$'), 'xy1_r47', lambda kern, radius: torch.load(os.path.abspath(os.path.join(disent.__file__, '../../../data/adversarial_kernel', 'r47-1_s28800_adam_lr0.003_wd0.0_xy1x1.pt')))), # pragma: delete-on-release (re.compile(r'^(box)_r(\d+)$'), 'box_r31', lambda kern, radius: torch_box_kernel_2d(radius=int(radius))[None, ...]), (re.compile(r'^(gau)_r(\d+)$'), 'gau_r31', lambda kern, radius: torch_gaussian_kernel_2d(sigma=int(radius) / 4.0, truncate=4.0)[None, None, ...]), ] diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py index 6c14422c..08be5c04 100644 --- a/disent/dataset/util/stats.py +++ b/disent/dataset/util/stats.py @@ -97,22 +97,8 @@ def main(progress=False): data.SmallNorbData, data.Shapes3dData, # groundtruth -- impl synthetic - data.XYBlocksData, # pragma: delete-on-release data.XYObjectData, data.XYObjectShadedData, - data.XYSquaresData, # pragma: delete-on-release - data.XYSquaresMinimalData, # pragma: delete-on-release - data.XColumnsData, # pragma: delete-on-release - # groundtruth -- increasing overlap # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=8)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=7)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=6)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=5)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=4)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=3)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=2)), # pragma: delete-on-release - (data.XYSquaresData, dict(grid_size=8, grid_spacing=1)), # pragma: delete-on-release - (data.XYSquaresData, dict(rgb=False)), # pragma: delete-on-release # large datasets (data.Mpi3dData, dict(subset='toy', in_memory=True)), (data.Mpi3dData, dict(subset='realistic', in_memory=True)), diff --git a/disent/frameworks/ae/experimental/__init__.py b/disent/frameworks/ae/experimental/__init__.py deleted file mode 100644 index 6055d82a..00000000 --- a/disent/frameworks/ae/experimental/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -# supervised frameworks -from disent.frameworks.ae.experimental._supervised__adaneg_tae import AdaNegTripletAe - -# unsupervised frameworks -from disent.frameworks.ae.experimental._unsupervised__dotae import DataOverlapTripletAe - -# weakly supervised frameworks -from disent.frameworks.ae.experimental._weaklysupervised__adaae import AdaAe diff --git a/disent/frameworks/ae/experimental/_supervised__adaneg_tae.py b/disent/frameworks/ae/experimental/_supervised__adaneg_tae.py deleted file mode 100644 index 647adda6..00000000 --- a/disent/frameworks/ae/experimental/_supervised__adaneg_tae.py +++ /dev/null @@ -1,71 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from dataclasses import dataclass -from numbers import Number -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple -from typing import Union - -import torch - -from disent.frameworks.ae._supervised__tae import TripletAe -from disent.frameworks.ae.experimental._weaklysupervised__adaae import AdaAe -from disent.frameworks.vae.experimental._supervised__adaneg_tvae import AdaNegTripletVae - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -class AdaNegTripletAe(TripletAe): - """ - This is a condensed version of the ada_tvae and adaave_tvae, - using approximately the best settings and loss... 
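-
-    Concretely: hook_ae_compute_ave_aug_loss below forwards the triplet of
-    latent codes to AdaNegTripletVae.estimate_ada_triplet_loss_from_zs, which
-    soft-scales the anchor-negative delta on latent units estimated as shared.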
- """ - - REQUIRED_OBS = 3 - - @dataclass - class cfg(TripletAe.cfg, AdaAe.cfg): - # ada_tvae - loss - adat_triplet_share_scale: float = 0.95 - - def hook_ae_compute_ave_aug_loss(self, zs: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[Union[torch.Tensor, Number], Dict[str, Any]]: - return AdaNegTripletVae.estimate_ada_triplet_loss_from_zs( - zs=zs, - cfg=self.cfg, - ) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/ae/experimental/_unsupervised__dotae.py b/disent/frameworks/ae/experimental/_unsupervised__dotae.py deleted file mode 100644 index 569e29f2..00000000 --- a/disent/frameworks/ae/experimental/_unsupervised__dotae.py +++ /dev/null @@ -1,76 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple - -import torch - -from disent.frameworks.ae.experimental._supervised__adaneg_tae import AdaNegTripletAe -from disent.frameworks.vae.experimental._supervised__adaneg_tvae import AdaNegTripletVae -from disent.frameworks.vae.experimental._unsupervised__dotvae import DataOverlapMixin - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Data Overlap Triplet AE # -# ========================================================================= # - - -class DataOverlapTripletAe(AdaNegTripletAe, DataOverlapMixin): - - REQUIRED_OBS = 1 - - @dataclass - class cfg(AdaNegTripletAe.cfg, DataOverlapMixin.cfg): - pass - - def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) - # initialise mixin - self.init_data_overlap_mixin() - - def hook_ae_compute_ave_aug_loss(self, zs: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, Any]]: - [z], [x_targ_orig] = zs, xs_targ - # 1. 
randomly generate and mine triplets using augmented versions of the inputs - a_idxs, p_idxs, n_idxs = self.random_mined_triplets(x_targ_orig=x_targ_orig) - # 2. compute triplet loss - loss, loss_log = AdaNegTripletVae.estimate_ada_triplet_loss_from_zs( - zs=[z[idxs] for idxs in (a_idxs, p_idxs, n_idxs)], - cfg=self.cfg, - ) - return loss, { - **loss_log, - } - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/ae/experimental/_weaklysupervised__adaae.py b/disent/frameworks/ae/experimental/_weaklysupervised__adaae.py deleted file mode 100644 index f38690a5..00000000 --- a/disent/frameworks/ae/experimental/_weaklysupervised__adaae.py +++ /dev/null @@ -1,81 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple - -import torch -from dataclasses import dataclass - -from disent.frameworks.ae._unsupervised__ae import Ae -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae - - -# ========================================================================= # -# Ada-GVAE # -# ========================================================================= # - - -class AdaAe(Ae): - """ - Custom implementation, removing Variational Auto-Encoder components of: - Weakly Supervised Disentanglement Learning Without Compromises: https://arxiv.org/abs/2002.02886 - - MODIFICATION: - - L1 distance for deltas instead of KL divergence - - adjustable threshold value - """ - - REQUIRED_OBS = 2 - - @dataclass - class cfg(Ae.cfg): - ada_thresh_ratio: float = 0.5 - - def hook_ae_intercept_zs(self, zs: Sequence[torch.Tensor]) -> Tuple[Sequence[torch.Tensor], Dict[str, Any]]: - """ - Adaptive VAE Glue Method, putting the various components together - 1. find differences between deltas - 2. estimate a threshold for differences - 3. compute a shared mask from this threshold - 4. average together elements that should be considered shared - - TODO: the methods used in this function should probably be moved here - TODO: this function could be turned into a torch.nn.Module! 
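-
-        Rough sketch of steps 1-3 (pseudo-code paraphrasing AdaVae's helper
-        semantics, not quoting them):
-
-            deltas = |z0 - z1|                                            # (1) element-wise differences
-            thresh = lerp(deltas.min(), deltas.max(), ada_thresh_ratio)  # (2) threshold per pair
-            share_mask = deltas < thresh                                 # (3) elements treated as shared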
- """ - z0, z1 = zs - # shared elements that need to be averaged, computed per pair in the batch. - share_mask = AdaVae.compute_shared_mask_from_zs(z0, z1, ratio=self.cfg.ada_thresh_ratio) - # compute average posteriors - new_zs = AdaVae.make_shared_zs(z0, z1, share_mask) - # return new args & generate logs - return new_zs, { - 'shared': share_mask.sum(dim=1).float().mean() - } - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 18b31cee..723d764f 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -281,8 +281,6 @@ def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_t # (REGEX, EXAMPLE, FACTORY_FUNC) # - factory function takes at min one arg: fn(reduction) with one arg after that per regex capture group # - regex expressions are tested in order, expressions should be mutually exclusive or ordered such that more specialized versions occur first. - (re.compile(r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_w(\d+\.\d+)$'), 'mse_xy8_r47_w1.0', lambda reduction, loss, kern, weight: AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=1-float(weight), aug_weight=float(weight))), # pragma: delete-on-release - (re.compile(r'^([a-z\d]+)_([a-z\d]+_[a-z\d]+)_l(\d+\.\d+)_k(\d+\.\d+)$'), 'mse_xy8_r47_l1.0_k1.0', lambda reduction, loss, kern, l_weight, k_weight: AugmentedReconLossHandler(make_reconstruction_loss(loss, reduction=reduction), kernel=kern, wrap_weight=float(l_weight), aug_weight=float(k_weight))), # pragma: delete-on-release ] diff --git a/disent/frameworks/vae/experimental/__init__.py b/disent/frameworks/vae/experimental/__init__.py deleted file mode 100644 index cb2006fd..00000000 --- a/disent/frameworks/vae/experimental/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -# supervised frameworks -from disent.frameworks.vae.experimental._supervised__adaave_tvae import AdaAveTripletVae -from disent.frameworks.vae.experimental._supervised__adaneg_tvae import AdaNegTripletVae -from disent.frameworks.vae.experimental._supervised__adatvae import AdaTripletVae -from disent.frameworks.vae.experimental._supervised__badavae import BoundedAdaVae -from disent.frameworks.vae.experimental._supervised__gadavae import GuidedAdaVae -from disent.frameworks.vae.experimental._supervised__tbadavae import TripletBoundedAdaVae -from disent.frameworks.vae.experimental._supervised__tgadavae import TripletGuidedAdaVae - -# unsupervised frameworks -from disent.frameworks.vae.experimental._unsupervised__dorvae import DataOverlapRankVae -from disent.frameworks.vae.experimental._unsupervised__dotvae import DataOverlapTripletVae - -# weakly supervised frameworks -from disent.frameworks.vae.experimental._weaklysupervised__augpostriplet import AugPosTripletVae -from disent.frameworks.vae.experimental._weaklysupervised__st_adavae import SwappedTargetAdaVae -from disent.frameworks.vae.experimental._weaklysupervised__st_betavae import SwappedTargetBetaVae diff --git a/disent/frameworks/vae/experimental/_supervised__adaave_tvae.py b/disent/frameworks/vae/experimental/_supervised__adaave_tvae.py deleted file mode 100644 index f1c93cfe..00000000 --- a/disent/frameworks/vae/experimental/_supervised__adaave_tvae.py +++ /dev/null @@ -1,120 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import warnings -from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple - -from disent.util.deprecate import deprecated -from torch.distributions import Distribution -from torch.distributions import Normal - -from disent.frameworks.vae.experimental._supervised__adatvae import AdaTripletVae -from disent.frameworks.vae.experimental._supervised__adatvae import compute_ave_shared_distributions -from disent.frameworks.vae.experimental._supervised__adatvae import compute_triplet_shared_masks - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -@deprecated('Rather use the AdaNegTripletVae') -class AdaAveTripletVae(AdaTripletVae): - """ - This was a more general attempt of the ada-tvae, - that also averages representations passed to the decoder. - - just averaging in this way without augmenting the loss with - triplet, or ada_triplet is too weak of a supervision signal. - """ - - REQUIRED_OBS = 3 - - @dataclass - class cfg(AdaTripletVae.cfg): - # adavae - ada_thresh_mode: str = 'dist' # RESET OVERRIDEN VALUE - # adaave_tvae - adaave_augment_orig: bool = True # triplet over original OR averaged embeddings - adaave_decode_orig: bool = True # decode & regularize original OR averaged embeddings - - def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) - # checks - if self.cfg.ada_thresh_mode != 'dist': - warnings.warn(f'cfg.ada_thresh_mode == {repr(self.cfg.ada_thresh_mode)}. Modes other than "dist" do not work well!') - val = (self.cfg.adat_triplet_loss != 'triplet_hard_ave_all') - if self.cfg.adaave_augment_orig == val: - warnings.warn(f'cfg.adaave_augment_orig == {repr(self.cfg.adaave_augment_orig)}. Modes other than {repr(val)} do not work well!') - if self.cfg.adaave_decode_orig == False: - warnings.warn(f'cfg.adaave_decode_orig == {repr(self.cfg.adaave_decode_orig)}. 
Modes other than True do not work well!') - - def hook_intercept_ds(self, ds_posterior: Sequence[Normal], ds_prior: Sequence[Normal]) -> Tuple[Sequence[Distribution], Sequence[Distribution], Dict[str, Any]]: - # triplet vae intercept -- in case detached - ds_posterior, ds_prior, intercept_logs = super().hook_intercept_ds(ds_posterior, ds_prior) - - # compute shared masks, shared embeddings & averages over shared embeddings - share_masks, share_logs = compute_triplet_shared_masks(ds_posterior, cfg=self.cfg) - ds_posterior_shared, ds_posterior_shared_ave = compute_ave_shared_distributions(ds_posterior, share_masks, cfg=self.cfg) - - # DIFFERENCE FROM ADAVAE | get return values - # adavae: adaave_augment_orig == True, adaave_decode_orig == False - ds_posterior_augment = (ds_posterior if self.cfg.adaave_augment_orig else ds_posterior_shared_ave) - ds_posterior_return = (ds_posterior if self.cfg.adaave_decode_orig else ds_posterior_shared_ave) - - # save params for aug_loss hook step - self._curr_ada_loss_kwargs = dict( - share_masks=share_masks, - zs=[d.mean for d in ds_posterior], - zs_shared=[d.mean for d in ds_posterior_shared], - zs_shared_ave=[d.mean for d in ds_posterior_augment], # USUALLY: zs_params_shared_ave - ) - - return ds_posterior_return, ds_prior, { - **intercept_logs, - **share_logs, - } - - def hook_compute_ave_aug_loss(self, ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ): - """ - NOTE: we don't use input parameters here, this function will only work - if called as part of training_step or do_training_step - """ - # compute triplet loss - result = AdaTripletVae.compute_ada_triplet_loss(**self._curr_ada_loss_kwargs, cfg=self.cfg) - # cleanup temporary variables - del self._curr_ada_loss_kwargs - # we are done - return result - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py b/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py deleted file mode 100644 index 469b4099..00000000 --- a/disent/frameworks/vae/experimental/_supervised__adaneg_tvae.py +++ /dev/null @@ -1,118 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from dataclasses import dataclass -from typing import Sequence - -import torch -from torch.distributions import Normal - -from disent.nn.loss.triplet import configured_dist_triplet -from disent.nn.loss.triplet import configured_triplet -from disent.frameworks.vae._supervised__tvae import TripletVae -from disent.frameworks.vae.experimental._supervised__adatvae import compute_triplet_shared_masks -from disent.frameworks.vae.experimental._supervised__adatvae import compute_triplet_shared_masks_from_zs -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -class AdaNegTripletVae(TripletVae): - - """ - This is a condensed version of the ada_tvae and adaave_tvae, - using approximately the best settings and loss... - """ - - REQUIRED_OBS = 3 - - @dataclass - class cfg(TripletVae.cfg, AdaVae.cfg): - # adavae - ada_thresh_mode: str = 'dist' # only works for: adat_share_mask_mode == "posterior" - # ada_tvae - loss - adat_triplet_share_scale: float = 0.95 - # ada_tvae - averaging - adat_share_mask_mode: str = 'posterior' - - def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Normal], ds_prior: Sequence[Normal], zs_sampled: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]): - return self.estimate_ada_triplet_loss( - ds_posterior=ds_posterior, - cfg=self.cfg, - ) - - @staticmethod - def estimate_ada_triplet_loss_from_zs(zs: Sequence[torch.Tensor], cfg: cfg): - # compute shared masks, shared embeddings & averages over shared embeddings - share_masks, share_logs = compute_triplet_shared_masks_from_zs(zs=zs, cfg=cfg) - # compute loss - ada_triplet_loss, ada_triplet_logs = AdaNegTripletVae.compute_ada_triplet_loss(share_masks=share_masks, zs=zs, cfg=cfg) - # merge logs & return loss - return ada_triplet_loss, { - **ada_triplet_logs, - **share_logs, - } - - @staticmethod - def estimate_ada_triplet_loss(ds_posterior: Sequence[Normal], cfg: cfg): - # compute shared masks, shared embeddings & averages over shared embeddings - share_masks, share_logs = compute_triplet_shared_masks(ds_posterior, cfg=cfg) - # compute loss - ada_triplet_loss, ada_triplet_logs = AdaNegTripletVae.compute_ada_triplet_loss(share_masks=share_masks, zs=(d.mean for d in ds_posterior), cfg=cfg) - # merge logs & return loss - return ada_triplet_loss, { - **ada_triplet_logs, - **share_logs, - } - - @staticmethod - def compute_ada_triplet_loss(share_masks, zs, cfg: cfg): - # Normal Triplet Loss - (a_z, p_z, n_z) = zs - trip_loss = configured_triplet(a_z, p_z, n_z, cfg=cfg) - - # Soft Scaled Negative Triplet - (ap_share_mask, an_share_mask, pn_share_mask) = share_masks - triplet_hard_neg_ave_scaled = configured_dist_triplet( - pos_delta=a_z - p_z, - neg_delta=torch.where(an_share_mask, cfg.adat_triplet_share_scale * (a_z - n_z), (a_z - n_z)), - cfg=cfg, - ) - - return triplet_hard_neg_ave_scaled, { - 'triplet': trip_loss, - 'triplet_chosen': triplet_hard_neg_ave_scaled, - } - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__adatvae.py 
b/disent/frameworks/vae/experimental/_supervised__adatvae.py deleted file mode 100644 index 4fa02ba9..00000000 --- a/disent/frameworks/vae/experimental/_supervised__adatvae.py +++ /dev/null @@ -1,328 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from dataclasses import dataclass -from typing import Sequence -from typing import Tuple - -import torch -from disent.util.deprecate import deprecated -from torch.distributions import Distribution -from torch.distributions import Normal - -from disent.nn.loss.triplet import configured_dist_triplet -from disent.nn.loss.triplet import configured_triplet -from disent.frameworks.vae._supervised__tvae import TripletVae -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -@deprecated('Rather use the AdaNegTripletVae') -class AdaTripletVae(TripletVae): - - REQUIRED_OBS = 3 - - @dataclass - class cfg(TripletVae.cfg, AdaVae.cfg): - # adavae - ada_thresh_mode: str = 'dist' # only works for: adat_share_mask_mode == "posterior" - # ada_tvae - loss - adat_triplet_loss: str = 'triplet_hard_neg_ave' # should be used with a schedule! - adat_triplet_ratio: float = 1.0 - adat_triplet_soft_scale: float = 1.0 - adat_triplet_pull_weight: float = 0.1 # only works for: adat_triplet_loss == "triplet_hard_neg_ave_pull" - adat_triplet_share_scale: float = 0.95 # only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: str = 'posterior' - adat_share_ave_mode: str = 'all' # only works for: adat_triplet_loss == "triplet_hard_ave_all" - - def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Normal], ds_prior: Sequence[Normal], zs_sampled: Sequence[torch.Tensor], xs_partial_recon: Sequence[torch.Tensor], xs_targ: Sequence[torch.Tensor]): - return self.estimate_ada_triplet_loss( - ds_posterior=ds_posterior, - cfg=self.cfg, - ) - - @staticmethod - def estimate_ada_triplet_loss(ds_posterior: Sequence[Normal], cfg: cfg): - """ - zs_params and ds_posterior are convenience variables here. 
- - they should contain the same values - - in practice we only need one of them and can compute the other! - """ - # compute shared masks, shared embeddings & averages over shared embeddings - share_masks, share_logs = compute_triplet_shared_masks(ds_posterior, cfg=cfg) - ds_posterior_shared, ds_posterior_shared_ave = compute_ave_shared_distributions(ds_posterior, share_masks, cfg=cfg) - - # compute loss - ada_triplet_loss, ada_triplet_logs = AdaTripletVae.compute_ada_triplet_loss( - share_masks=share_masks, - zs=[d.mean for d in ds_posterior], - zs_shared=[d.mean for d in ds_posterior_shared], - zs_shared_ave=[d.mean for d in ds_posterior_shared_ave], - cfg=cfg, - ) - - return ada_triplet_loss, { - **ada_triplet_logs, - **share_logs, - } - - @staticmethod - def compute_ada_triplet_loss(share_masks: Sequence[torch.Tensor], zs: Sequence[Normal], zs_shared: Sequence[Normal], zs_shared_ave: Sequence[Normal], cfg: cfg): - - # Normal Triplet Loss - (a_z, p_z, n_z) = zs - trip_loss = configured_triplet(a_z, p_z, n_z, cfg=cfg) - - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # Hard Losses - zs_shared - # TODO: implement triplet over KL divergence rather than l1/l2 distance? - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - - # Hard Averaging Before Triplet - (ap_a_ave, ap_p_ave, an_a_ave, an_n_ave, pn_p_ave, pn_n_ave) = zs_shared - triplet_hard_ave = configured_dist_triplet(pos_delta=ap_a_ave - ap_p_ave, neg_delta=an_a_ave - an_n_ave, cfg=cfg) - triplet_hard_ave_neg = configured_dist_triplet(pos_delta=a_z - p_z, neg_delta=an_a_ave - an_n_ave, cfg=cfg) - - # Hard Averaging Before Triplet - PULLING PUSHING - (ap_share_mask, an_share_mask, pn_share_mask) = share_masks - neg_delta_push = torch.where(~an_share_mask, a_z - n_z, torch.zeros_like(a_z)) # this is the same as: an_a_ave - an_n_ave - neg_delta_pull = torch.where( an_share_mask, a_z - n_z, torch.zeros_like(a_z)) - triplet_hard_ave_neg_pull = configured_dist_push_pull_triplet(pos_delta=a_z - p_z, neg_delta=neg_delta_push, neg_delta_pull=neg_delta_pull, cfg=cfg) - - # Hard All Averaging Before Triplet - (a_ave, p_ave, n_ave) = zs_shared_ave - triplet_all_hard_ave = configured_dist_triplet(pos_delta=a_ave-p_ave, neg_delta=a_ave-n_ave, cfg=cfg) - - # Soft Scaled Negative Triplet - triplet_hard_neg_ave_scaled = configured_dist_triplet( - pos_delta=a_z - p_z, - neg_delta=torch.where(an_share_mask, cfg.adat_triplet_share_scale * (a_z - n_z), (a_z - n_z)), - cfg=cfg, - ) - - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # Soft Losses - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - - # Individual Pair Averaging Losses - _soft_ap_loss = configured_soft_ave_loss(share_mask=ap_share_mask, delta=a_z - p_z, cfg=cfg) - _soft_an_loss = configured_soft_ave_loss(share_mask=an_share_mask, delta=a_z - n_z, cfg=cfg) - _soft_pn_loss = configured_soft_ave_loss(share_mask=pn_share_mask, delta=p_z - n_z, cfg=cfg) - - # soft losses - soft_loss_an = (_soft_an_loss) - soft_loss_an_ap = (_soft_an_loss + _soft_ap_loss) / 2 - soft_loss_an_ap_pn = (_soft_an_loss + _soft_ap_loss + _soft_pn_loss) / 3 - - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # Return - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - - losses = { - 'triplet': trip_loss, - # soft ave - 'triplet_soft_ave_neg': trip_loss + soft_loss_an, - 'triplet_soft_ave_p_n': trip_loss + soft_loss_an_ap, - 'triplet_soft_ave_all': trip_loss + soft_loss_an_ap_pn, - # hard ave - 'triplet_hard_ave': torch.lerp(trip_loss, triplet_hard_ave, weight=cfg.adat_triplet_ratio), - 'triplet_hard_neg_ave': torch.lerp(trip_loss, triplet_hard_ave_neg, weight=cfg.adat_triplet_ratio), - 
'triplet_hard_neg_ave_pull': torch.lerp(trip_loss, triplet_hard_ave_neg_pull, weight=cfg.adat_triplet_ratio), - 'triplet_hard_ave_all': torch.lerp(trip_loss, triplet_all_hard_ave, weight=cfg.adat_triplet_ratio), - # scaled - 'triplet_hard_neg_ave_scaled': torch.lerp(trip_loss, triplet_hard_neg_ave_scaled, weight=cfg.adat_triplet_ratio), - } - - return losses[cfg.adat_triplet_loss], { - 'triplet': trip_loss, - 'triplet_chosen': losses[cfg.adat_triplet_loss], - } - - -# ========================================================================= # -# Ada-TVae # -# ========================================================================= # - - -def dist_push_pull_triplet(pos_delta, neg_delta, neg_delta_pull, margin_max=1., p=1, pull_weight=1.): - """ - Pushing Pulling Triplet Loss - - should match standard triplet loss if pull_weight=0. - """ - p_dist = torch.norm(pos_delta, p=p, dim=-1) - n_dist = torch.norm(neg_delta, p=p, dim=-1) - n_dist_pull = torch.norm(neg_delta_pull, p=p, dim=-1) - loss = torch.clamp_min(p_dist - n_dist + margin_max + pull_weight * n_dist_pull, 0) - return loss.mean() - - -def configured_dist_push_pull_triplet(pos_delta, neg_delta, neg_delta_pull, cfg: AdaTripletVae.cfg): - """ - required config params: - - cfg.triplet_margin_max: (0, inf) - - cfg.triplet_p: 1 or 2 - - cfg.triplet_scale: [0, inf) - - cfg.adat_triplet_pull_weight: [0, 1] - """ - return dist_push_pull_triplet( - pos_delta=pos_delta, neg_delta=neg_delta, neg_delta_pull=neg_delta_pull, - margin_max=cfg.triplet_margin_max, p=cfg.triplet_p, pull_weight=cfg.adat_triplet_pull_weight, - ) * cfg.triplet_scale - - -def soft_ave_loss(share_mask, delta): - return torch.norm(torch.where(share_mask, delta, torch.zeros_like(delta)), p=2, dim=-1).mean() - - -def configured_soft_ave_loss(share_mask, delta, cfg: AdaTripletVae.cfg): - """ - required config params: - - cfg.triplet_scale: [0, inf) - - cfg.adat_triplet_soft_scale: [0, inf) - """ - return soft_ave_loss(share_mask=share_mask, delta=delta) * (cfg.adat_triplet_soft_scale * cfg.triplet_scale) - - -# ========================================================================= # -# AveAda-TVAE # -# ========================================================================= # - - -def compute_triplet_shared_masks_from_zs(zs: Sequence[torch.Tensor], cfg): - """ - required config params: - - cfg.ada_thresh_ratio: - """ - a_z, p_z, n_z = zs - # shared elements that need to be averaged, computed per pair in the batch. - ap_share_mask = AdaVae.compute_shared_mask_from_zs(a_z, p_z, ratio=cfg.ada_thresh_ratio) - an_share_mask = AdaVae.compute_shared_mask_from_zs(a_z, n_z, ratio=cfg.ada_thresh_ratio) - pn_share_mask = AdaVae.compute_shared_mask_from_zs(p_z, n_z, ratio=cfg.ada_thresh_ratio) - # return values - share_masks = (ap_share_mask, an_share_mask, pn_share_mask) - return share_masks, { - 'ap_shared': ap_share_mask.sum(dim=1).float().mean(), - 'an_shared': an_share_mask.sum(dim=1).float().mean(), - 'pn_shared': pn_share_mask.sum(dim=1).float().mean(), - } - - -def compute_triplet_shared_masks(ds_posterior: Sequence[Distribution], cfg: AdaTripletVae.cfg): - """ - required config params: - - cfg.ada_thresh_ratio: - - cfg.ada_thresh_mode: "kl", "symmetric_kl", "dist", "sampled_dist" - : only applies if cfg.ada_share_mask_mode=="posterior" - - cfg.adat_share_mask_mode: "posterior", "sample", "sample_each" - """ - a_posterior, p_posterior, n_posterior = ds_posterior - - # shared elements that need to be averaged, computed per pair in the batch. 
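-    # the modes below differ only in where the comparison happens:
-    #   "posterior"   -- threshold a distance/divergence between the two posteriors
-    #   "sample"      -- draw one rsample() per posterior and reuse it for all pairs
-    #   "sample_each" -- draw a fresh rsample() for every pair comparison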
- if cfg.adat_share_mask_mode == 'posterior': - ap_share_mask = AdaVae.compute_shared_mask_from_posteriors(a_posterior, p_posterior, thresh_mode=cfg.ada_thresh_mode, ratio=cfg.ada_thresh_ratio) - an_share_mask = AdaVae.compute_shared_mask_from_posteriors(a_posterior, n_posterior, thresh_mode=cfg.ada_thresh_mode, ratio=cfg.ada_thresh_ratio) - pn_share_mask = AdaVae.compute_shared_mask_from_posteriors(p_posterior, n_posterior, thresh_mode=cfg.ada_thresh_mode, ratio=cfg.ada_thresh_ratio) - elif cfg.adat_share_mask_mode == 'sample': - a_z_sample, p_z_sample, n_z_sample = a_posterior.rsample(), p_posterior.rsample(), n_posterior.rsample() - ap_share_mask = AdaVae.compute_shared_mask_from_zs(a_z_sample, p_z_sample, ratio=cfg.ada_thresh_ratio) - an_share_mask = AdaVae.compute_shared_mask_from_zs(a_z_sample, n_z_sample, ratio=cfg.ada_thresh_ratio) - pn_share_mask = AdaVae.compute_shared_mask_from_zs(p_z_sample, n_z_sample, ratio=cfg.ada_thresh_ratio) - elif cfg.adat_share_mask_mode == 'sample_each': - ap_share_mask = AdaVae.compute_shared_mask_from_zs(a_posterior.rsample(), p_posterior.rsample(), ratio=cfg.ada_thresh_ratio) - an_share_mask = AdaVae.compute_shared_mask_from_zs(a_posterior.rsample(), n_posterior.rsample(), ratio=cfg.ada_thresh_ratio) - pn_share_mask = AdaVae.compute_shared_mask_from_zs(p_posterior.rsample(), n_posterior.rsample(), ratio=cfg.ada_thresh_ratio) - else: - raise KeyError(f'Invalid cfg.adat_share_mask_mode={repr(cfg.adat_share_mask_mode)}') - - # return values - share_masks = (ap_share_mask, an_share_mask, pn_share_mask) - return share_masks, { - 'ap_shared': ap_share_mask.sum(dim=1).float().mean(), - 'an_shared': an_share_mask.sum(dim=1).float().mean(), - 'pn_shared': pn_share_mask.sum(dim=1).float().mean(), - } - - -def compute_ave_shared_distributions(ds_posterior: Sequence[Normal], share_masks: Sequence[torch.Tensor], cfg: AdaTripletVae.cfg) -> Tuple[Sequence[Normal], Sequence[Normal]]: - """ - required config params: - - cfg.ada_average_mode: "gvae", "ml-vae" - - cfg.adat_share_ave_mode: "all", "pos_neg", "pos", "neg" - """ - a_posterior, p_posterior, n_posterior = ds_posterior - ap_share_mask, an_share_mask, pn_share_mask = share_masks - - # compute shared embeddings - ave_ap_a_posterior, ave_ap_p_posterior = AdaVae.make_shared_posteriors(a_posterior, p_posterior, ap_share_mask, average_mode=cfg.ada_average_mode) - ave_an_a_posterior, ave_an_n_posterior = AdaVae.make_shared_posteriors(a_posterior, n_posterior, an_share_mask, average_mode=cfg.ada_average_mode) - ave_pn_p_posterior, ave_pn_n_posterior = AdaVae.make_shared_posteriors(p_posterior, n_posterior, pn_share_mask, average_mode=cfg.ada_average_mode) - - # compute averaged shared embeddings - if cfg.adat_share_ave_mode == 'all': - ave_a_posterior = AdaVae.compute_average_distribution(ave_ap_a_posterior, ave_an_a_posterior, average_mode=cfg.ada_average_mode) - ave_p_posterior = AdaVae.compute_average_distribution(ave_ap_p_posterior, ave_pn_p_posterior, average_mode=cfg.ada_average_mode) - ave_n_posterior = AdaVae.compute_average_distribution(ave_an_n_posterior, ave_pn_n_posterior, average_mode=cfg.ada_average_mode) - elif cfg.adat_share_ave_mode == 'pos_neg': - ave_a_posterior = AdaVae.compute_average_distribution(ave_ap_a_posterior, ave_an_a_posterior, average_mode=cfg.ada_average_mode) - ave_p_posterior = ave_ap_p_posterior - ave_n_posterior = ave_an_n_posterior - elif cfg.adat_share_ave_mode == 'pos': - ave_a_posterior = ave_ap_a_posterior - ave_p_posterior = ave_ap_p_posterior - ave_n_posterior = n_posterior 
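
The masks and averages threaded through these helpers are easy to misread, so here is a minimal self-contained sketch of the two steps: a per-pair shared mask over latent deltas, then averaging of only the shared dimensions. It assumes the usual Ada-GVAE conventions (a threshold placed `ratio` of the way between the smallest and largest per-dimension delta, and the 'gvae' average mode taking the arithmetic mean of means and variances); the real logic lives in `AdaVae.compute_shared_mask_from_zs` and `AdaVae.make_shared_posteriors`.

import torch
from torch.distributions import Normal

def shared_mask(z0, z1, ratio=0.5):
    # per-pair threshold `ratio` of the way between the min and max per-dim delta
    deltas = torch.abs(z0 - z1)
    mins = deltas.min(dim=1, keepdim=True).values
    maxs = deltas.max(dim=1, keepdim=True).values
    return deltas < (mins + ratio * (maxs - mins))

def gvae_average(d0: Normal, d1: Normal, mask):
    # arithmetic mean of means and variances, applied over the shared dims only
    ave_mean = 0.5 * (d0.mean + d1.mean)
    ave_var = 0.5 * (d0.variance + d1.variance)
    # return only the d0-side averaged posterior (the real helper returns both)
    mean = torch.where(mask, ave_mean, d0.mean)
    std = torch.where(mask, ave_var, d0.variance).sqrt()
    return Normal(mean, std)

torch.manual_seed(0)
d0 = Normal(torch.randn(4, 6), torch.rand(4, 6) + 0.1)
d1 = Normal(torch.randn(4, 6), torch.rand(4, 6) + 0.1)
mask = shared_mask(d0.mean, d1.mean)          # (4, 6) boolean
print(mask.sum(dim=1))                        # number of shared dims per pair
print(gvae_average(d0, d1, mask).mean.shape)  # torch.Size([4, 6])
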
- elif cfg.adat_share_ave_mode == 'neg': - ave_a_posterior = ave_an_a_posterior - ave_p_posterior = p_posterior - ave_n_posterior = ave_an_n_posterior - else: - raise KeyError(f'Invalid cfg.adat_share_ave_mode={repr(cfg.adat_share_ave_mode)}') - - ds_posterior_shared = ( - ave_ap_a_posterior, ave_ap_p_posterior, # a & p - ave_an_a_posterior, ave_an_n_posterior, # a & n - ave_pn_p_posterior, ave_pn_n_posterior, # p & n - ) - - ds_posterior_shared_ave = ( - ave_a_posterior, - ave_p_posterior, - ave_n_posterior - ) - - # return values - return ds_posterior_shared, ds_posterior_shared_ave - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__badavae.py b/disent/frameworks/vae/experimental/_supervised__badavae.py deleted file mode 100644 index ccc77a54..00000000 --- a/disent/frameworks/vae/experimental/_supervised__badavae.py +++ /dev/null @@ -1,121 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple - -import torch -from torch.distributions import Distribution - -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -class BoundedAdaVae(AdaVae): - - REQUIRED_OBS = 3 - - @dataclass - class cfg(AdaVae.cfg): - pass - - def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]) -> Tuple[Sequence[Distribution], Sequence[Distribution], Dict[str, Any]]: - a_posterior, p_posterior, n_posterior = ds_posterior - - # get deltas - a_p_deltas = AdaVae.compute_deltas_from_posteriors(a_posterior, p_posterior, thresh_mode=self.cfg.ada_thresh_mode) - a_n_deltas = AdaVae.compute_deltas_from_posteriors(a_posterior, n_posterior, thresh_mode=self.cfg.ada_thresh_mode) - - # shared elements that need to be averaged, computed per pair in the batch. 
-        old_p_shared_mask = AdaVae.estimate_shared_mask(a_p_deltas, ratio=self.cfg.ada_thresh_ratio)
-        old_n_shared_mask = AdaVae.estimate_shared_mask(a_n_deltas, ratio=self.cfg.ada_thresh_ratio)
-
-        # modify threshold based on criterion and recompute if necessary
-        # CORE of this approach!
-        p_shared_mask, n_shared_mask = compute_constrained_masks(a_p_deltas, old_p_shared_mask, a_n_deltas, old_n_shared_mask)
-
-        # make averaged variables
-        # TODO: this will probably be better if it is the negative involved
-        # TODO: this can be merged with the gadavae/badavae
-        ave_ap_a_posterior, ave_ap_p_posterior = AdaVae.make_shared_posteriors(a_posterior, p_posterior, p_shared_mask, average_mode=self.cfg.ada_average_mode)
-
-        # TODO: n_z_params should not be here! this does not match the original version
-        #       number of loss elements is not 2 like the original
-        #       - recons gets 2 items, p & a only
-        #       - reg gets 2 items, p & a only
-        new_ds_posterior = (ave_ap_a_posterior, ave_ap_p_posterior, n_posterior)
-
-        # return new args & generate logs
-        # -- we only return 2 parameters a & p, not n
-        return new_ds_posterior, ds_prior, {
-            'p_shared_before': old_p_shared_mask.sum(dim=1).float().mean(),
-            'p_shared_after': p_shared_mask.sum(dim=1).float().mean(),
-            'n_shared_before': old_n_shared_mask.sum(dim=1).float().mean(),
-            'n_shared_after': n_shared_mask.sum(dim=1).float().mean(),
-        }
-
-
-# ========================================================================= #
-# HELPER                                                                    #
-# ========================================================================= #
-
-
-def compute_constrained_masks(p_kl_deltas, p_shared_mask, n_kl_deltas, n_shared_mask):
-    # number of shared factors (per pair in the batch)
-    p_shared_num = torch.sum(p_shared_mask, dim=1, keepdim=True)
-    n_shared_num = torch.sum(n_shared_mask, dim=1, keepdim=True)
-
-    # POSITIVE SHARED MASK
-    # order from smallest to largest
-    p_sort_indices = torch.argsort(p_kl_deltas, dim=1)
-    # p_shared should be at least n_shared
-    new_p_shared_num = torch.max(p_shared_num, n_shared_num)
-
-    # NEGATIVE SHARED MASK
-    # order from smallest to largest
-    n_sort_indices = torch.argsort(n_kl_deltas, dim=1)
-    # n_shared should be at most p_shared
-    new_n_shared_num = torch.min(p_shared_num, n_shared_num)
-
-    # COMPUTE NEW MASKS
-    new_p_shared_mask = torch.zeros_like(p_shared_mask)
-    new_n_shared_mask = torch.zeros_like(n_shared_mask)
-    for i, (new_shared_p, new_shared_n) in enumerate(zip(new_p_shared_num, new_n_shared_num)):
-        new_p_shared_mask[i, p_sort_indices[i, :new_shared_p]] = True
-        new_n_shared_mask[i, n_sort_indices[i, :new_shared_n]] = True
-
-    # return masks
-    return new_p_shared_mask, new_n_shared_mask
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/disent/frameworks/vae/experimental/_supervised__gadavae.py b/disent/frameworks/vae/experimental/_supervised__gadavae.py
deleted file mode 100644
index a7e7f381..00000000
--- a/disent/frameworks/vae/experimental/_supervised__gadavae.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from dataclasses import dataclass -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Tuple - -from torch.distributions import Distribution - -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae -from disent.frameworks.vae.experimental._supervised__badavae import compute_constrained_masks - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -class GuidedAdaVae(AdaVae): - - REQUIRED_OBS = 3 - - @dataclass - class cfg(AdaVae.cfg): - gada_anchor_ave_mode: str = 'average' - - def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) - # how the anchor is averaged - assert cfg.gada_anchor_ave_mode in {'thresh', 'average'} - - def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequence[Distribution]) -> Tuple[Sequence[Distribution], Sequence[Distribution], Dict[str, Any]]: - """ - *NB* arguments must satisfy: d(l, l2) < d(l, l3) [positive dist < negative dist] - - This function assumes that the distance between labels l, l2, l3 - corresponding to z, z2, z3 satisfy the criteria d(l, l2) < d(l, l3) - ie. l2 is the positive sample, l3 is the negative sample - """ - a_posterior, p_posterior, n_posterior = ds_posterior - - # get deltas - a_p_deltas = AdaVae.compute_deltas_from_posteriors(a_posterior, p_posterior, thresh_mode=self.cfg.ada_thresh_mode) - a_n_deltas = AdaVae.compute_deltas_from_posteriors(a_posterior, n_posterior, thresh_mode=self.cfg.ada_thresh_mode) - - # shared elements that need to be averaged, computed per pair in the batch. - old_p_shared_mask = AdaVae.estimate_shared_mask(a_p_deltas, ratio=self.cfg.ada_thresh_ratio) - old_n_shared_mask = AdaVae.estimate_shared_mask(a_n_deltas, ratio=self.cfg.ada_thresh_ratio) - - # modify threshold based on criterion and recompute if necessary - # CORE of this approach! 
- p_shared_mask, n_shared_mask = compute_constrained_masks(a_p_deltas, old_p_shared_mask, a_n_deltas, old_n_shared_mask) - - # make averaged variables - # TODO: this can be merged with the gadavae/badavae - ave_ap_a_posterior, ave_ap_p_posterior = AdaVae.make_shared_posteriors(a_posterior, p_posterior, p_shared_mask, average_mode=self.cfg.ada_average_mode) - ave_an_a_posterior, ave_an_n_posterior = AdaVae.make_shared_posteriors(a_posterior, n_posterior, n_shared_mask, average_mode=self.cfg.ada_average_mode) - ave_a_posterior = AdaVae.compute_average_distribution(ave_ap_a_posterior, ave_an_a_posterior, average_mode=self.cfg.ada_average_mode) - - # compute anchor average using the adaptive threshold | TODO: this doesn't really make sense - anchor_ave_logs = {} - if self.cfg.gada_anchor_ave_mode == 'thresh': - ave_shared_mask = p_shared_mask * n_shared_mask - ave_params, _ = AdaVae.make_shared_posteriors(a_posterior, ave_a_posterior, ave_shared_mask, average_mode=self.cfg.ada_average_mode) - anchor_ave_logs['ave_shared'] = ave_shared_mask.sum(dim=1).float().mean() - - new_ds_posterior = ave_a_posterior, ave_ap_p_posterior, ave_an_n_posterior - - return new_ds_posterior, ds_prior, { - 'p_shared_before': old_p_shared_mask.sum(dim=1).float().mean(), - 'p_shared_after': p_shared_mask.sum(dim=1).float().mean(), - 'n_shared_before': old_n_shared_mask.sum(dim=1).float().mean(), - 'n_shared_after': n_shared_mask.sum(dim=1).float().mean(), - **anchor_ave_logs, - } - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__tbadavae.py b/disent/frameworks/vae/experimental/_supervised__tbadavae.py deleted file mode 100644 index 9e5caf8d..00000000 --- a/disent/frameworks/vae/experimental/_supervised__tbadavae.py +++ /dev/null @@ -1,51 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
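
For intuition about `compute_constrained_masks`, used by both BoundedAdaVae and GuidedAdaVae above: the positive pair is forced to share at least as many latent dimensions as the negative pair (and the negative at most as many), with dimensions re-picked in order of increasing delta. A toy re-run of the same argsort logic, as a minimal standalone sketch:

import torch

def constrained_masks(p_deltas, p_mask, n_deltas, n_mask):
    # p_shared >= n_shared, re-chosen from the smallest deltas, as above
    p_num = p_mask.sum(dim=1, keepdim=True)
    n_num = n_mask.sum(dim=1, keepdim=True)
    new_p_num, new_n_num = torch.max(p_num, n_num), torch.min(p_num, n_num)
    new_p, new_n = torch.zeros_like(p_mask), torch.zeros_like(n_mask)
    for i in range(len(p_deltas)):
        new_p[i, torch.argsort(p_deltas[i])[:new_p_num[i]]] = True
        new_n[i, torch.argsort(n_deltas[i])[:new_n_num[i]]] = True
    return new_p, new_n

p_deltas = torch.tensor([[0.1, 0.9, 0.2, 0.8]])
n_deltas = torch.tensor([[0.7, 0.1, 0.9, 0.2]])
p_mask = torch.tensor([[True, False, False, False]])  # 1 shared dim
n_mask = torch.tensor([[False, True, False, True]])   # 2 shared dims
new_p, new_n = constrained_masks(p_deltas, p_mask, n_deltas, n_mask)
assert new_p.sum() == 2 and new_n.sum() == 1  # p lifted to 2, n capped at 1
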
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from dataclasses import dataclass - -from disent.frameworks.vae.experimental._supervised__badavae import BoundedAdaVae -from disent.nn.loss.triplet import compute_triplet_loss -from disent.nn.loss.triplet import TripletLossConfig - - -# ========================================================================= # -# tbadavae # -# ========================================================================= # - - -class TripletBoundedAdaVae(BoundedAdaVae): - - REQUIRED_OBS = 3 - - @dataclass - class cfg(BoundedAdaVae.cfg, TripletLossConfig): - pass - - def hook_compute_ave_aug_loss(self, ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ): - return compute_triplet_loss(zs=[d.mean for d in ds_posterior], cfg=self.cfg) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_supervised__tgadavae.py b/disent/frameworks/vae/experimental/_supervised__tgadavae.py deleted file mode 100644 index 0739e751..00000000 --- a/disent/frameworks/vae/experimental/_supervised__tgadavae.py +++ /dev/null @@ -1,51 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
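
TripletBoundedAdaVae above (and TripletGuidedAdaVae below) carry no logic of their own beyond the `cfg` class, which merges the parent framework's config with `TripletLossConfig` through dataclass multiple inheritance, so both sets of hyper-parameters land on one flat config object. A minimal sketch of that composition pattern, with toy classes whose field names only echo the cfg params referenced in this diff:

from dataclasses import dataclass, fields

@dataclass
class AdaCfg:
    ada_thresh_ratio: float = 0.5

@dataclass
class TripletCfg:
    triplet_margin_max: float = 1.0
    triplet_scale: float = 1.0

@dataclass
class CombinedCfg(AdaCfg, TripletCfg):
    # inherits defaults from both parents; the MRO resolves any clashes
    pass

cfg = CombinedCfg(ada_thresh_ratio=0.25)
print([f.name for f in fields(cfg)])
# ['triplet_margin_max', 'triplet_scale', 'ada_thresh_ratio']
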
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from dataclasses import dataclass - -from disent.frameworks.vae.experimental._supervised__gadavae import GuidedAdaVae -from disent.nn.loss.triplet import compute_triplet_loss -from disent.nn.loss.triplet import TripletLossConfig - - -# ========================================================================= # -# tgadavae # -# ========================================================================= # - - -class TripletGuidedAdaVae(GuidedAdaVae): - - REQUIRED_OBS = 3 - - @dataclass - class cfg(GuidedAdaVae.cfg, TripletLossConfig): - pass - - def hook_compute_ave_aug_loss(self, ds_posterior, ds_prior, zs_sampled, xs_partial_recon, xs_targ): - return compute_triplet_loss(zs=[d.mean for d in ds_posterior], cfg=self.cfg) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py b/disent/frameworks/vae/experimental/_unsupervised__dorvae.py deleted file mode 100644 index d8b139f4..00000000 --- a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py +++ /dev/null @@ -1,169 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from dataclasses import dataclass -from typing import final -from typing import Optional -from typing import Sequence - -import torch -from torch.distributions import Normal - -from disent.frameworks.helper.reconstructions import make_reconstruction_loss -from disent.frameworks.helper.reconstructions import ReconLossHandler -from disent.frameworks.vae._supervised__tvae import TripletVae -from disent.frameworks.vae._weaklysupervised__adavae import AdaVae -from disent.nn.loss.softsort import torch_mse_rank_loss -from disent.nn.loss.softsort import spearman_rank_loss - - -# ========================================================================= # -# tvae # -# ========================================================================= # - - -class DataOverlapRankVae(TripletVae): - """ - This converges really well! 
-      - but doesn't introduce axis alignment as well if there is no additional
-        inward pressure term, like triplet, to move representations closer together
-    """
-
-    REQUIRED_OBS = 1
-
-    @dataclass
-    class cfg(TripletVae.cfg):
-        # compatibility
-        ada_thresh_mode: str = 'dist'  # kl, symmetric_kl, dist, sampled_dist
-        ada_thresh_ratio: float = 0.5
-        adat_triplet_share_scale: float = 0.95
-        # OVERLAP VAE
-        overlap_loss: Optional[str] = None
-        overlap_num: int = 1024
-        # AUGMENT
-        overlap_augment_mode: str = 'none'
-        overlap_augment: Optional[dict] = None
-        # REPRESENTATIONS
-        overlap_repr: str = 'deterministic'  # deterministic, stochastic
-        overlap_rank_mode: str = 'spearman_rank'  # spearman_rank, mse_rank
-        overlap_inward_pressure_masked: bool = False
-        overlap_inward_pressure_scale: float = 0.1
-
-    def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
-        # TODO: duplicate code
-        super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
-        # initialise
-        if self.cfg.overlap_augment_mode != 'none':
-            assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.'
-        # set augment and instantiate if needed
-        self._augment = self.cfg.overlap_augment
-        if isinstance(self._augment, dict):
-            import hydra
-            self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
-            assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'
-        # get overlap loss
-        overlap_loss = self.cfg.overlap_loss if (self.cfg.overlap_loss is not None) else self.cfg.recon_loss
-        self.__overlap_handler: ReconLossHandler = make_reconstruction_loss(overlap_loss, reduction='mean')
-
-    @final
-    @property
-    def overlap_handler(self) -> ReconLossHandler:
-        return self.__overlap_handler
-
-    def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Normal], ds_prior, zs_sampled, xs_partial_recon, xs_targ: Sequence[torch.Tensor]):
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # 1. augment batch
-        (x_targ_orig,) = xs_targ
-        with torch.no_grad():
-            (x_targ,) = self.augment_triplet_targets(xs_targ)
-        (d_posterior,) = ds_posterior
-        (z_sampled,) = zs_sampled
-        # 2. generate random pairs -- this does not generate unique pairs
-        a_idxs, p_idxs = torch.randint(len(x_targ), size=(2, self.cfg.overlap_num), device=x_targ.device)
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # compute image distances
-        with torch.no_grad():
-            ap_recon_dists = self.overlap_handler.compute_pairwise_loss(x_targ[a_idxs], x_targ[p_idxs])
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # get representations
-        if self.cfg.overlap_repr == 'deterministic':
-            a_z, p_z = d_posterior.loc[a_idxs], d_posterior.loc[p_idxs]
-        elif self.cfg.overlap_repr == 'stochastic':
-            a_z, p_z = z_sampled[a_idxs], z_sampled[p_idxs]
-        else:
-            raise KeyError(f'invalid overlap_repr mode: {repr(self.cfg.overlap_repr)}')
-        # DISENTANGLE!
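
Before the adaptive masking that follows, the core objective deserves a standalone look: latent pairwise distances should rank the same way as image pairwise distances. Below is a toy, non-differentiable illustration of that rank correlation; disent's actual `spearman_rank_loss` is a differentiable soft-ranking version, so this sketch only shows the quantity being targeted.

import torch

def ranks(x: torch.Tensor) -> torch.Tensor:
    # rank of each element within the vector (0 = smallest)
    return torch.argsort(torch.argsort(x)).float()

def spearman_corr(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Pearson correlation of the two rank vectors
    ra, rb = ranks(a), ranks(b)
    ra, rb = ra - ra.mean(), rb - rb.mean()
    return (ra * rb).sum() / (ra.norm() * rb.norm())

torch.manual_seed(0)
img_dists = torch.rand(100)                      # pairwise image distances
lat_dists = img_dists + 0.01 * torch.randn(100)  # nearly the same ordering
print(spearman_corr(lat_dists, img_dists))       # close to 1.0; the loss is its negation
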
-        # compute adaptive mask & weight deltas
-        a_posterior = Normal(d_posterior.loc[a_idxs], d_posterior.scale[a_idxs])
-        p_posterior = Normal(d_posterior.loc[p_idxs], d_posterior.scale[p_idxs])
-        share_mask = AdaVae.compute_shared_mask_from_posteriors(a_posterior, p_posterior, thresh_mode=self.cfg.ada_thresh_mode, ratio=self.cfg.ada_thresh_ratio)
-        deltas = torch.where(share_mask, self.cfg.adat_triplet_share_scale * (a_z - p_z), (a_z - p_z))
-        # compute representation distances
-        ap_repr_dists = torch.abs(deltas).sum(dim=-1)
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        if self.cfg.overlap_rank_mode == 'mse_rank':
-            loss = torch_mse_rank_loss(ap_repr_dists, ap_recon_dists.detach(), dims=-1, reduction='mean')
-            loss_logs = {'mse_rank_loss': loss}
-        elif self.cfg.overlap_rank_mode == 'spearman_rank':
-            loss = - spearman_rank_loss(ap_repr_dists, ap_recon_dists.detach(), nan_to_num=True)
-            loss_logs = {'spearman_rank_loss': loss}
-        else:
-            raise KeyError(f'invalid overlap_rank_mode: {repr(self.cfg.overlap_rank_mode)}')
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # inward pressure
-        if self.cfg.overlap_inward_pressure_masked:
-            in_deltas = torch.abs(deltas) * share_mask
-        else:
-            in_deltas = torch.abs(deltas)
-        # compute inward pressure
-        inward_pressure = self.cfg.overlap_inward_pressure_scale * in_deltas.mean()
-        loss += inward_pressure
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # return the loss
-        return loss, {
-            **loss_logs,
-            'inward_pressure': inward_pressure,
-        }
-
-    def augment_triplet_targets(self, xs_targ):
-        # TODO: duplicate code
-        if self.cfg.overlap_augment_mode == 'none':
-            aug_xs_targ = xs_targ
-        elif (self.cfg.overlap_augment_mode == 'augment') or (self.cfg.overlap_augment_mode == 'augment_each'):
-            # recreate augment each time
-            if self.cfg.overlap_augment_mode == 'augment_each':
-                import hydra
-                self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
-            # augment on correct device
-            aug_xs_targ = [self._augment(x_targ) for x_targ in xs_targ]
-            # checks
-            assert all(a.shape == b.shape for a, b in zip(xs_targ, aug_xs_targ))
-        else:
-            raise KeyError(f'invalid cfg.overlap_augment_mode={repr(self.cfg.overlap_augment_mode)}')
-        return aug_xs_targ
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/disent/frameworks/vae/experimental/_unsupervised__dotvae.py b/disent/frameworks/vae/experimental/_unsupervised__dotvae.py
deleted file mode 100644
index 4ca79e19..00000000
--- a/disent/frameworks/vae/experimental/_unsupervised__dotvae.py
+++ /dev/null
@@ -1,222 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-import logging
-from dataclasses import dataclass
-from typing import final
-from typing import Optional
-from typing import Sequence
-
-import torch
-from torch.distributions import Normal
-
-from disent.frameworks.helper.reconstructions import make_reconstruction_loss
-from disent.frameworks.helper.reconstructions import ReconLossHandler
-from disent.frameworks.vae.experimental._supervised__adaneg_tvae import AdaNegTripletVae
-from disent.nn.loss.triplet_mining import configured_idx_mine
-
-
-log = logging.getLogger(__name__)
-
-
-# ========================================================================= #
-# Mixin                                                                     #
-# ========================================================================= #
-
-
-class DataOverlapMixin(object):
-
-    # should be inherited by the config on the child class
-    @dataclass
-    class cfg:
-        # override from AE
-        recon_loss: str = 'mse'
-        # OVERLAP VAE
-        overlap_loss: Optional[str] = None  # if None, use the value from recon_loss
-        overlap_num: int = 1024
-        overlap_mine_ratio: float = 0.1
-        overlap_mine_triplet_mode: str = 'none'
-        # AUGMENT
-        overlap_augment_mode: str = 'none'
-        overlap_augment: Optional[dict] = None
-
-    # private properties
-    # - since this class does not have a constructor, it
-    #   provides the `init_data_overlap_mixin` method, which
-    #   should be called inside the constructor of the child class
-    _augment: callable
-    _overlap_handler: ReconLossHandler
-    _init: bool
-
-    def init_data_overlap_mixin(self):
-        if hasattr(self, '_init'):
-            raise RuntimeError(f'{DataOverlapMixin.__name__} on {self.__class__.__name__} was initialised more than once!')
-        self._init = True
-        # initialise
-        if self.cfg.overlap_augment_mode != 'none':
-            assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.'
-        # set augment and instantiate if needed
-        self._augment = self.cfg.overlap_augment
-        if isinstance(self._augment, dict):
-            import hydra
-            self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
-            assert callable(self._augment), f'augment is not callable: {repr(self._augment)}'
-        # get overlap loss
-        overlap_loss = self.cfg.overlap_loss if (self.cfg.overlap_loss is not None) else self.cfg.recon_loss
-        self._overlap_handler: ReconLossHandler = make_reconstruction_loss(overlap_loss, reduction='mean')
-        # delete this property, we only ever want to be able to call this once!
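
Since the mixin deliberately has no constructor, the contract is that a child framework calls `init_data_overlap_mixin()` exactly once from its own `__init__`, as `DataOverlapTripletVae` does further down. A minimal sketch of that once-only pattern, with stand-in classes to keep it self-contained:

from dataclasses import dataclass

class Mixin:
    def init_mixin(self):
        # guard against double initialisation, as in init_data_overlap_mixin
        if hasattr(self, '_init'):
            raise RuntimeError('mixin initialised more than once!')
        self._init = True

class Framework(Mixin):
    @dataclass
    class cfg:
        overlap_augment_mode: str = 'none'

    def __init__(self, cfg: cfg = None):
        self.cfg = cfg if (cfg is not None) else Framework.cfg()
        self.init_mixin()  # must be called once, here

f = Framework()
try:
    f.init_mixin()
except RuntimeError as e:
    print(e)  # mixin initialised more than once!
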
-
-    @final
-    @property
-    def overlap_handler(self) -> ReconLossHandler:
-        return self._overlap_handler
-
-    def overlap_swap_triplet_idxs(self, x_targ, a_idxs, p_idxs, n_idxs):
-        xs_targ = [x_targ[idxs] for idxs in (a_idxs, p_idxs, n_idxs)]
-        # CORE: order the latent variables for triplet
-        swap_mask = self.overlap_swap_mask(xs_targ=xs_targ)
-        # swap all idxs
-        swapped_a_idxs = a_idxs
-        swapped_p_idxs = torch.where(swap_mask, n_idxs, p_idxs)
-        swapped_n_idxs = torch.where(swap_mask, p_idxs, n_idxs)
-        # return values
-        return swapped_a_idxs, swapped_p_idxs, swapped_n_idxs
-
-    def overlap_swap_mask(self, xs_targ: Sequence[torch.Tensor]) -> torch.Tensor:
-        # get variables
-        a_x_targ_OLD, p_x_targ_OLD, n_x_targ_OLD = xs_targ
-        # CORE OF THIS APPROACH
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # calculate which are wrong!
-        # TODO: add more loss functions, like perceptual & others
-        with torch.no_grad():
-            a_p_losses = self.overlap_handler.compute_pairwise_loss(a_x_targ_OLD, p_x_targ_OLD)  # (B, C, H, W) -> (B,)
-            a_n_losses = self.overlap_handler.compute_pairwise_loss(a_x_targ_OLD, n_x_targ_OLD)  # (B, C, H, W) -> (B,)
-            swap_mask = (a_p_losses > a_n_losses)  # (B,)
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        return swap_mask
-
-    @torch.no_grad()
-    def augment_batch(self, x_targ):
-        if self.cfg.overlap_augment_mode == 'none':
-            aug_x_targ = x_targ
-        elif self.cfg.overlap_augment_mode in ('augment', 'augment_each'):
-            # recreate augment each time
-            if self.cfg.overlap_augment_mode == 'augment_each':
-                import hydra
-                self._augment = hydra.utils.instantiate(self.cfg.overlap_augment)
-            # augment on correct device
-            aug_x_targ = self._augment(x_targ)
-        else:
-            raise KeyError(f'invalid cfg.overlap_augment_mode={repr(self.cfg.overlap_augment_mode)}')
-        # checks
-        assert x_targ.shape == aug_x_targ.shape
-        return aug_x_targ
-
-    def mine_triplets(self, x_targ, a_idxs, p_idxs, n_idxs):
-        return configured_idx_mine(
-            x_targ=x_targ,
-            a_idxs=a_idxs,
-            p_idxs=p_idxs,
-            n_idxs=n_idxs,
-            cfg=self.cfg,
-            pairwise_loss_fn=self.overlap_handler.compute_pairwise_loss,
-        )
-
-    def random_mined_triplets(self, x_targ_orig: torch.Tensor):
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # 1. augment batch
-        aug_x_targ = self.augment_batch(x_targ_orig)
-        # 2. generate random triples -- this does not generate unique pairs
-        a_idxs, p_idxs, n_idxs = torch.randint(len(aug_x_targ), size=(3, min(self.cfg.overlap_num, len(aug_x_targ)**3)), device=aug_x_targ.device)
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # self.debug(x_targ_orig, x_targ, a_idxs, p_idxs, n_idxs)
-        # ++++++++++++++++++++++++++++++++++++++++++ #
-        # TODO: this can be merged into a single function -- inefficient currently with deltas computed twice
-        # 3. reorder random triples
-        a_idxs, p_idxs, n_idxs = self.overlap_swap_triplet_idxs(aug_x_targ, a_idxs, p_idxs, n_idxs)
-        # 4. 
mine random triples - a_idxs, p_idxs, n_idxs = self.mine_triplets(aug_x_targ, a_idxs, p_idxs, n_idxs) - # ++++++++++++++++++++++++++++++++++++++++++ # - return a_idxs, p_idxs, n_idxs - - # def debug(self, x_targ_orig, x_targ, a_idxs, p_idxs, n_idxs): - # a_p_overlap_orig = - self.recon_handler.compute_unreduced_loss(x_targ_orig[a_idxs], x_targ_orig[p_idxs]).mean(dim=(-3, -2, -1)) # (B, C, H, W) -> (B,) - # a_n_overlap_orig = - self.recon_handler.compute_unreduced_loss(x_targ_orig[a_idxs], x_targ_orig[n_idxs]).mean(dim=(-3, -2, -1)) # (B, C, H, W) -> (B,) - # a_p_overlap = - self.recon_handler.compute_unreduced_loss(x_targ[a_idxs], x_targ[p_idxs]).mean(dim=(-3, -2, -1)) # (B, C, H, W) -> (B,) - # a_n_overlap = - self.recon_handler.compute_unreduced_loss(x_targ[a_idxs], x_targ[n_idxs]).mean(dim=(-3, -2, -1)) # (B, C, H, W) -> (B,) - # a_p_overlap_mul = - (a_p_overlap_orig * a_p_overlap) - # a_n_overlap_mul = - (a_n_overlap_orig * a_n_overlap) - # # check number of things - # (up_values_orig, up_counts_orig) = torch.unique(a_p_overlap_orig, sorted=True, return_inverse=False, return_counts=True) - # (un_values_orig, un_counts_orig) = torch.unique(a_n_overlap_orig, sorted=True, return_inverse=False, return_counts=True) - # (up_values, up_counts) = torch.unique(a_p_overlap, sorted=True, return_inverse=False, return_counts=True) - # (un_values, un_counts) = torch.unique(a_n_overlap, sorted=True, return_inverse=False, return_counts=True) - # (up_values_mul, up_counts_mul) = torch.unique(a_p_overlap_mul, sorted=True, return_inverse=False, return_counts=True) - # (un_values_mul, un_counts_mul) = torch.unique(a_n_overlap_mul, sorted=True, return_inverse=False, return_counts=True) - # # plot! - # plt.scatter(up_values_orig.detach().cpu(), torch.cumsum(up_counts_orig, dim=-1).detach().cpu()) - # plt.scatter(un_values_orig.detach().cpu(), torch.cumsum(un_counts_orig, dim=-1).detach().cpu()) - # plt.scatter(up_values.detach().cpu(), torch.cumsum(up_counts, dim=-1).detach().cpu()) - # plt.scatter(un_values.detach().cpu(), torch.cumsum(un_counts, dim=-1).detach().cpu()) - # plt.scatter(up_values_mul.detach().cpu(), torch.cumsum(up_counts_mul, dim=-1).detach().cpu()) - # plt.scatter(un_values_mul.detach().cpu(), torch.cumsum(un_counts_mul, dim=-1).detach().cpu()) - # plt.show() - # time.sleep(10) - - -# ========================================================================= # -# Data Overlap Triplet VAE # -# ========================================================================= # - - -class DataOverlapTripletVae(AdaNegTripletVae, DataOverlapMixin): - - REQUIRED_OBS = 1 - - @dataclass - class cfg(AdaNegTripletVae.cfg, DataOverlapMixin.cfg): - pass - - def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) - # initialise mixin - self.init_data_overlap_mixin() - - def hook_compute_ave_aug_loss(self, ds_posterior: Sequence[Normal], ds_prior, zs_sampled, xs_partial_recon, xs_targ: Sequence[torch.Tensor]): - [d_posterior], [x_targ_orig] = ds_posterior, xs_targ - # 1. randomly generate and mine triplets using augmented versions of the inputs - a_idxs, p_idxs, n_idxs = self.random_mined_triplets(x_targ_orig=x_targ_orig) - # 2. 
compute triplet loss - loss, loss_log = AdaNegTripletVae.estimate_ada_triplet_loss( - ds_posterior=[Normal(d_posterior.loc[idxs], d_posterior.scale[idxs]) for idxs in (a_idxs, p_idxs, n_idxs)], - cfg=self.cfg, - ) - return loss, { - **loss_log, - } - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py b/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py deleted file mode 100644 index bad6354a..00000000 --- a/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py +++ /dev/null @@ -1,82 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import warnings -from dataclasses import dataclass -from typing import Union - -import torch - -from disent.frameworks.vae._supervised__tvae import TripletVae - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Guided Ada Vae # -# ========================================================================= # - - -class AugPosTripletVae(TripletVae): - - REQUIRED_OBS = 2 # third obs is generated from augmentations - - @dataclass - class cfg(TripletVae.cfg): - overlap_augment: Union[dict, callable] = None - - def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): - super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) - # set augment and instantiate if needed - self._augment = self.cfg.overlap_augment - if isinstance(self._augment, dict): - import hydra - self._augment = hydra.utils.instantiate(self._augment) - # get default if needed - if self._augment is None: - self._augment = torch.nn.Identity() - warnings.warn(f'{self.__class__.__name__}, no overlap_augment was specified, defaulting to nn.Identity which WILL break things!') - # checks! 
- assert callable(self._augment), f'augment is not callable: {repr(self._augment)}' - - def do_training_step(self, batch, batch_idx): - (a_x, n_x), (a_x_targ, n_x_targ) = self._get_xs_and_targs(batch, batch_idx) - - # generate augmented items - with torch.no_grad(): - p_x_targ = a_x_targ - p_x = self._augment(a_x) - # a_x = self._aug(a_x) - # n_x = self._aug(n_x) - - batch['x'], batch['x_targ'] = (a_x, p_x, n_x), (a_x_targ, p_x_targ, n_x_targ) - # compute! - return super().do_training_step(batch, batch_idx) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/disent/frameworks/vae/experimental/_weaklysupervised__st_adavae.py b/disent/frameworks/vae/experimental/_weaklysupervised__st_adavae.py deleted file mode 100644 index deddced4..00000000 --- a/disent/frameworks/vae/experimental/_weaklysupervised__st_adavae.py +++ /dev/null @@ -1,63 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
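
The entire mechanism of `AugPosTripletVae` above is the batch rewrite: given only an (anchor, negative) pair, the positive is manufactured by augmenting the anchor so that an ordinary triplet loss applies. A standalone sketch of that rewrite, where the lambda is a toy stand-in for whatever `cfg.overlap_augment` instantiates:

import torch

def make_triplet_batch(a_x: torch.Tensor, n_x: torch.Tensor, augment):
    # the positive is an augmented copy of the anchor; its target stays the anchor
    with torch.no_grad():
        p_x = augment(a_x)
    return (a_x, p_x, n_x)

augment = lambda x: x + 0.05 * torch.randn_like(x)  # toy stand-in augmentation
a_x, n_x = torch.rand(2, 8, 3, 64, 64).unbind(0)
a, p, n = make_triplet_batch(a_x, n_x, augment)
assert p.shape == a.shape and not torch.equal(p, a)
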
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-from dataclasses import dataclass
-
-import numpy as np
-from disent.frameworks.vae._weaklysupervised__adavae import AdaVae
-
-
-# ========================================================================= #
-# Swapped Target AdaVae                                                     #
-# ========================================================================= #
-
-
-class SwappedTargetAdaVae(AdaVae):
-
-    REQUIRED_OBS = 2
-
-    @dataclass
-    class cfg(AdaVae.cfg):
-        swap_chance: float = 0.1
-
-    def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
-        super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
-        assert cfg.swap_chance >= 0
-
-    def do_training_step(self, batch, batch_idx):
-        (x0, x1), (x0_targ, x1_targ) = self._get_xs_and_targs(batch, batch_idx)
-
-        # random chance for the target not to be equal to the input
-        if np.random.random() < self.cfg.swap_chance:
-            x0_targ, x1_targ = x1_targ, x0_targ
-
-        return super(SwappedTargetAdaVae, self).do_training_step({
-            'x': (x0, x1),
-            'x_targ': (x0_targ, x1_targ),
-        }, batch_idx)
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/disent/frameworks/vae/experimental/_weaklysupervised__st_betavae.py b/disent/frameworks/vae/experimental/_weaklysupervised__st_betavae.py
deleted file mode 100644
index c3042059..00000000
--- a/disent/frameworks/vae/experimental/_weaklysupervised__st_betavae.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-from dataclasses import dataclass
-
-import numpy as np
-from disent.frameworks.vae._unsupervised__betavae import BetaVae
-
-
-# ========================================================================= #
-# Swapped Target BetaVAE                                                    #
-# ========================================================================= #
-
-
-class SwappedTargetBetaVae(BetaVae):
-
-    REQUIRED_OBS = 2
-
-    @dataclass
-    class cfg(BetaVae.cfg):
-        swap_chance: float = 0.1
-
-    def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None):
-        super().__init__(model=model, cfg=cfg, batch_augment=batch_augment)
-        assert cfg.swap_chance >= 0
-
-    def do_training_step(self, batch, batch_idx):
-        (x0, x1), (x0_targ, x1_targ) = self._get_xs_and_targs(batch, batch_idx)
-
-        # random chance for the target not to be equal to the input
-        if np.random.random() < self.cfg.swap_chance:
-            x0_targ, x1_targ = x1_targ, x0_targ
-
-        return super(SwappedTargetBetaVae, self).do_training_step({
-            'x': (x0,),
-            'x_targ': (x0_targ,),
-        }, batch_idx)
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/disent/metrics/__init__.py b/disent/metrics/__init__.py
index e4fb6784..98099e76 100644
--- a/disent/metrics/__init__.py
+++ b/disent/metrics/__init__.py
@@ -28,9 +28,6 @@
 from ._mig import metric_mig
 from ._sap import metric_sap
 from ._unsupervised import metric_unsupervised
-# Nathan Michlo et. al # pragma: delete-on-release
-from ._flatness import metric_flatness # pragma: delete-on-release
-from ._flatness_components import metric_flatness_components # pragma: delete-on-release
 
 
 # ========================================================================= #
@@ -45,8 +42,6 @@
 FAST_METRICS = {
     'dci': _wrapped_partial(metric_dci, num_train=1000, num_test=500, boost_mode='sklearn'),
     'factor_vae': _wrapped_partial(metric_factor_vae, num_train=700, num_eval=350, num_variance_estimate=1000),  # may not be accurate, but it just takes waay too long otherwise 20+ seconds
-    'flatness': _wrapped_partial(metric_flatness, factor_repeats=128), # pragma: delete-on-release
-    'flatness_components': _wrapped_partial(metric_flatness_components, factor_repeats=128), # pragma: delete-on-release
     'mig': _wrapped_partial(metric_mig, num_train=2000),
     'sap': _wrapped_partial(metric_sap, num_train=2000, num_test=1000),
     'unsupervised': _wrapped_partial(metric_unsupervised, num_train=2000),
@@ -55,8 +50,6 @@
 DEFAULT_METRICS = {
     'dci': metric_dci,
     'factor_vae': metric_factor_vae,
-    'flatness': metric_flatness, # pragma: delete-on-release
-    'flatness_components': metric_flatness_components, # pragma: delete-on-release
     'mig': metric_mig,
     'sap': metric_sap,
     'unsupervised': metric_unsupervised,
diff --git a/disent/metrics/_flatness.py b/disent/metrics/_flatness.py
deleted file mode 100644
index 1bdf05e4..00000000
--- a/disent/metrics/_flatness.py
+++ /dev/null
@@ -1,347 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and 
to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-"""
-Flatness Metric
-- Nathan Michlo 2021 (Unpublished)
-- Cite disent
-"""
-
-import logging
-import math
-from typing import Iterable
-from typing import Union
-
-import torch
-from disent.util.deprecate import deprecated
-from torch.utils.data.dataloader import default_collate
-
-from disent.dataset import DisentDataset
-from disent.util.iters import iter_chunks
-
-
-log = logging.getLogger(__name__)
-
-
-# ========================================================================= #
-# flatness                                                                  #
-# ========================================================================= #
-
-
-@deprecated('flatness metric is deprecated in favour of flatness_components, this metric still gives useful alternative info however.')
-def metric_flatness(
-        dataset: DisentDataset,
-        representation_function: callable,
-        factor_repeats: int = 1024,
-        batch_size: int = 64,
-):
-    """
-    Computes the flatness metric:
-        approximately equal to: total_dim_width / (ave_point_dist_along_dim * num_points_along_dim)
-
-    Complexity of this metric is:
-        O(num_factors * ave_factor_size * repeats)
-        eg. 9 factors * 64 indices on ave * 128 repeats = 73728 observations loaded from the dataset
-
-    factor_repeats:
-    - can go all the way down to about 64 and still get decent results.
-    - 64 is accurate to about +- 0.01
-    - 128 is accurate to about +- 0.003
-    - 1024 is accurate to about +- 0.001
-
-    Args:
-      dataset: DisentDataset to be sampled from.
-      representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation.
-      factor_repeats: how many times to repeat a traversal along each factor, these are then averaged together.
-      batch_size: Batch size to process at any time while generating representations, should not affect metric results.
-    Returns:
-      Dictionary with the computed flatness scores: average flatness ratios,
-      widths, lengths and cosine angles (under the L1 and L2 norms).
- """ - p_fs_measures = aggregate_measure_distances_along_all_factors(dataset, representation_function, repeats=factor_repeats, batch_size=batch_size, ps=(1, 2)) - # get info - factor_sizes = dataset.gt_data.factor_sizes - # aggregate data - results = { - 'flatness.ave_flatness': compute_flatness(widths=p_fs_measures[2]['fs_ave_widths'], lengths=p_fs_measures[1]['fs_ave_lengths'], factor_sizes=factor_sizes), - 'flatness.ave_flatness_l1': compute_flatness(widths=p_fs_measures[1]['fs_ave_widths'], lengths=p_fs_measures[1]['fs_ave_lengths'], factor_sizes=factor_sizes), - 'flatness.ave_flatness_l2': compute_flatness(widths=p_fs_measures[2]['fs_ave_widths'], lengths=p_fs_measures[2]['fs_ave_lengths'], factor_sizes=factor_sizes), - # distances - 'flatness.ave_width_l1': torch.mean(filter_inactive_factors(p_fs_measures[1]['fs_ave_widths'], factor_sizes=factor_sizes)), - 'flatness.ave_width_l2': torch.mean(filter_inactive_factors(p_fs_measures[2]['fs_ave_widths'], factor_sizes=factor_sizes)), - 'flatness.ave_length_l1': torch.mean(filter_inactive_factors(p_fs_measures[1]['fs_ave_lengths'], factor_sizes=factor_sizes)), - 'flatness.ave_length_l2': torch.mean(filter_inactive_factors(p_fs_measures[2]['fs_ave_lengths'], factor_sizes=factor_sizes)), - # angles - 'flatness.cosine_angles': (1 / math.pi) * torch.mean(filter_inactive_factors(p_fs_measures[1]['fs_ave_angles'], factor_sizes=factor_sizes)), - } - # convert values from torch - return {k: float(v) for k, v in results.items()} - - -def compute_flatness(widths, lengths, factor_sizes): - widths = filter_inactive_factors(widths, factor_sizes) - lengths = filter_inactive_factors(lengths, factor_sizes) - # checks - assert torch.all(widths >= 0) - assert torch.all(lengths >= 0) - assert torch.all(torch.eq(widths == 0, lengths == 0)) - # update scores - widths[lengths == 0] = 0 - lengths[lengths == 0] = 1 - # compute flatness - return (widths / lengths).mean() - - -def filter_inactive_factors(tensor, factor_sizes): - factor_sizes = torch.tensor(factor_sizes, device=tensor.device) - assert torch.all(factor_sizes >= 1) - # remove - active_factors = torch.nonzero(factor_sizes-1, as_tuple=True) - return tensor[active_factors] - - -def aggregate_measure_distances_along_all_factors( - dataset: DisentDataset, - representation_function, - repeats: int, - batch_size: int, - ps: Iterable[Union[str, int]] = (1, 2), -) -> dict: - # COMPUTE AGGREGATES FOR EACH FACTOR - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - fs_p_measures = [ - aggregate_measure_distances_along_factor(dataset, representation_function, f_idx=f_idx, repeats=repeats, batch_size=batch_size, ps=ps) - for f_idx in range(dataset.gt_data.num_factors) - ] - - # FINALIZE FOR EACH FACTOR - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - p_fs_measures = {} - for p, fs_measures in default_collate(fs_p_measures).items(): - fs_ave_widths = fs_measures['ave_width'] - # get number of spaces deltas (number of points minus 1) - # compute length: estimated version of factors_ave_width = factors_num_deltas * factors_ave_delta - _fs_num_deltas = torch.as_tensor(dataset.gt_data.factor_sizes, device=fs_ave_widths.device) - 1 - _fs_ave_deltas = fs_measures['ave_delta'] - fs_ave_lengths = _fs_num_deltas * _fs_ave_deltas - # angles - fs_ave_angles = fs_measures['ave_angle'] - # update - p_fs_measures[p] = {'fs_ave_widths': fs_ave_widths, 'fs_ave_lengths': fs_ave_lengths, 'fs_ave_angles': fs_ave_angles} - return p_fs_measures - - -def aggregate_measure_distances_along_factor( - dataset: DisentDataset, - 
representation_function, - f_idx: int, - repeats: int, - batch_size: int, - ps: Iterable[Union[str, int]] = (1, 2), - cycle_fail: bool = False, -) -> dict: - f_size = dataset.gt_data.factor_sizes[f_idx] - - if f_size == 1: - if cycle_fail: - raise ValueError(f'dataset factor size is too small for flatness metric with cycle_normalize enabled! size={f_size} < 2') - zero = torch.as_tensor(0., device=get_device(dataset, representation_function)) - return {p: {'ave_width': zero.clone(), 'ave_delta': zero.clone(), 'ave_angle': zero.clone()} for p in ps} - - # FEED FORWARD, COMPUTE ALL DELTAS & WIDTHS - For each distance measure - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - p_measures: list = [{} for _ in range(repeats)] - for measures in p_measures: - # generate repeated factors, varying one factor over the entire range - zs_traversal = encode_all_along_factor(dataset, representation_function, f_idx=f_idx, batch_size=batch_size) - # for each distance measure compute everything - # - width: calculate the distance between the furthest two points - # - deltas: calculating the distances of their representations to the next values. - # - cycle_normalize: we cant get the ave next dist directly because of cycles, so we remove the largest dist - for p in ps: - deltas_next = torch.norm(torch.roll(zs_traversal, -1, dims=0) - zs_traversal, dim=-1, p=p) # next | shape: (factor_size, z_size) - deltas_prev = torch.norm(torch.roll(zs_traversal, 1, dims=0) - zs_traversal, dim=-1, p=p) # prev | shape: (factor_size, z_size) - # values needed for flatness - width = knn(x=zs_traversal, y=zs_traversal, k=1, largest=True, p=p).values.max() # shape: (,) - min_deltas = torch.topk(deltas_next, k=f_size-1, dim=-1, largest=False, sorted=False) # shape: (factor_size-1, z_size) - # values needed for cosine angles - # TODO: this should not be calculated per p - # TODO: should we filter the cyclic value? - # a. if the point is an endpoint we set its value to pi indicating that it is flat - # b. [THIS] we do not allow less than 3 points, ie. a factor_size of at least 3, otherwise - # we set the angle to pi (considered flat) and filter the factor from the metric - angles = angles_between(deltas_next, deltas_prev, dim=-1, nan_to_angle=0) # shape: (factor_size,) - # TODO: other measures can be added: - # 1. multivariate skewness - # 2. normality measure - # 3. independence - # 4. menger curvature (Cayley-Menger Determinant?) 
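
The width/length ratio being assembled here is easiest to see on a toy traversal: for a straight, equally spaced line the end-to-end width equals the path length (flatness 1), while a zig-zag has a much longer path than its width. A minimal sketch using the same ingredients described above (max pairwise distance for the width; sum of the step sizes with the largest, cyclic step dropped for the length); the real metric additionally averages over repeats and over both the L1 and L2 norms:

import torch

def toy_flatness(zs: torch.Tensor) -> torch.Tensor:
    # width: largest pairwise distance between any two points on the traversal
    width = torch.cdist(zs, zs).max()
    # length: sum of consecutive step sizes, dropping the largest (cyclic) step
    steps = torch.norm(zs.roll(-1, dims=0) - zs, dim=-1)
    length = steps.topk(k=len(zs) - 1, largest=False).values.sum()
    return width / length

line = torch.stack([torch.arange(8.), torch.zeros(8)], dim=-1)  # straight traversal
zig = line.clone(); zig[1::2, 1] = 1.0                          # zig-zag traversal
print(toy_flatness(line))  # 1.0 (perfectly flat)
print(toy_flatness(zig))   # < 1.0
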
- # save variables - measures[p] = {'widths': width, 'deltas': min_deltas.values, 'angles': angles} - - # AGGREGATE DATA - For each distance measure - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - return { - p: { - 'ave_width': measures['widths'].mean(dim=0), # shape: (repeats,) -> () - 'ave_delta': measures['deltas'].mean(dim=[0, 1]), # shape: (repeats, factor_size - 1) -> () - 'ave_angle': measures['angles'].mean(dim=0), # shape: (repeats,) -> () - } for p, measures in default_collate(p_measures).items() - } - - -# ========================================================================= # -# ENCODE # -# ========================================================================= # - - -def encode_all_along_factor(dataset: DisentDataset, representation_function, f_idx: int, batch_size: int): - # generate repeated factors, varying one factor over a range (f_size, f_dims) - factors = dataset.gt_data.sample_random_factor_traversal(f_idx=f_idx) - # get the representations of all the factors (f_size, z_size) - sequential_zs = encode_all_factors(dataset, representation_function, factors=factors, batch_size=batch_size) - return sequential_zs - - -def encode_all_factors(dataset: DisentDataset, representation_function, factors, batch_size: int) -> torch.Tensor: - zs = [] - with torch.no_grad(): - for batch_factors in iter_chunks(factors, chunk_size=batch_size): - batch = dataset.dataset_batch_from_factors(batch_factors, mode='input') - z = representation_function(batch) - zs.append(z) - return torch.cat(zs, dim=0) - - -def get_device(dataset: DisentDataset, representation_function): - # this is a hack... - return representation_function(dataset.dataset_sample_batch(1, mode='input')).device - - -# ========================================================================= # -# DISTANCES # -# ========================================================================= # - - -def knn(x, y, k: int = None, largest=False, p='fro'): - assert 0 < k <= y.shape[0] - # check input vectors, must be array of vectors - assert 2 == x.ndim == y.ndim - assert x.shape[1:] == y.shape[1:] - # compute distances between each and every pair - dist_mat = x[:, None, ...] - y[None, :, ...] 
- dist_mat = torch.norm(dist_mat, dim=-1, p=p) - # return closest distances - return torch.topk(dist_mat, k=k, dim=-1, largest=largest, sorted=True) - - -# ========================================================================= # -# ANGLES # -# ========================================================================= # - - -def angles_between(a, b, dim=-1, nan_to_angle=None): - a = a / torch.norm(a, dim=dim, keepdim=True) - b = b / torch.norm(b, dim=dim, keepdim=True) - dot = torch.sum(a * b, dim=dim) - angles = torch.acos(torch.clamp(dot, -1.0, 1.0)) - if nan_to_angle is not None: - return torch.where(torch.isnan(angles), torch.full_like(angles, fill_value=nan_to_angle), angles) - return angles - - -# ========================================================================= # -# END # -# ========================================================================= # - - -# if __name__ == '__main__': -# import pytorch_lightning as pl -# from torch.optim import Adam -# from torch.utils.data import DataLoader -# from disent.data.groundtruth import XYObjectData, XYSquaresData -# from disent.dataset.groundtruth import GroundTruthDataset, GroundTruthDatasetPairs -# from disent.frameworks.vae import BetaVae -# from disent.frameworks.vae import AdaVae -# from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -# from disent.transform import ToImgTensorF32 -# from disent.util import colors -# from disent.util import Timer -# -# def get_str(r): -# return ', '.join(f'{k}={v:6.4f}' for k, v in r.items()) -# -# def print_r(name, steps, result, clr=colors.lYLW, t: Timer = None): -# print(f'{clr}{name:<13} ({steps:>04}){f" {colors.GRY}[{t.pretty}]{clr}" if t else ""}: {get_str(result)}{colors.RST}') -# -# def calculate(name, steps, dataset, get_repr): -# global aggregate_measure_distances_along_factor -# with Timer() as t: -# r = metric_flatness(dataset, get_repr, factor_repeats=64, batch_size=64) -# results.append((name, steps, r)) -# print_r(name, steps, r, colors.lRED, t=t) -# print(colors.GRY, '='*100, colors.RST, sep='') -# return r -# -# class XYOverlapData(XYSquaresData): -# def __init__(self, square_size=8, image_size=64, grid_spacing=None, num_squares=3, rgb=True): -# if grid_spacing is None: -# grid_spacing = (square_size+1) // 2 -# super().__init__(square_size=square_size, image_size=image_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) -# -# # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] -# datasets = [XYObjectData()] -# -# results = [] -# for data in datasets: -# dataset = GroundTruthDatasetPairs(data, transform=ToImgTensorF32()) -# dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True) -# module = AdaVae( -# model=AutoEncoder( -# encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), -# decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), -# ), -# cfg=AdaVae.cfg(beta=0.001, loss_reduction='mean', optimizer=torch.optim.Adam, optimizer_kwargs=dict(lr=5e-4)) -# ) -# # we cannot guarantee which device the representation is on -# get_repr = lambda x: module.encode(x.to(module.device)) -# # PHASE 1, UNTRAINED -# pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=True, gpus=1, weights_summary=None).fit(module, dataloader) -# module = module.to('cuda') -# calculate(data.__class__.__name__, 0, dataset, get_repr) -# # PHASE 2, LITTLE TRAINING -# pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, gpus=1, 
weights_summary=None).fit(module, dataloader) -# calculate(data.__class__.__name__, 256, dataset, get_repr) -# # PHASE 3, MORE TRAINING -# pl.Trainer(logger=False, checkpoint_callback=False, max_steps=2048, gpus=1, weights_summary=None).fit(module, dataloader) -# calculate(data.__class__.__name__, 256+2048, dataset, get_repr) -# results.append(None) -# -# for result in results: -# if result is None: -# print() -# continue -# (name, steps, result) = result -# print_r(name, steps, result, colors.lYLW) diff --git a/disent/metrics/_flatness_components.py b/disent/metrics/_flatness_components.py deleted file mode 100644 index 37798319..00000000 --- a/disent/metrics/_flatness_components.py +++ /dev/null @@ -1,412 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -Flatness Metric Components -- Nathan Michlo 2021 (Unpublished) -- Cite disent -""" - -import logging - -import numpy as np -import torch -from torch.utils.data.dataloader import default_collate - -from disent.dataset import DisentDataset -from disent.metrics._flatness import encode_all_along_factor -from disent.metrics._flatness import encode_all_factors -from disent.metrics._flatness import filter_inactive_factors -from disent.util.iters import iter_chunks -from disent.util import to_numpy -from disent.nn.functional import torch_mean_generalized -from disent.nn.functional import torch_pca - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# flatness # -# ========================================================================= # - - -def metric_flatness_components( - dataset: DisentDataset, - representation_function: callable, - factor_repeats: int = 1024, - batch_size: int = 64, -): - """ - Computes the flatness metric components (ordering, linearity & axis alignment): - global_swap_ratio: how often latent distances preserve the ground-truth factor ordering, over random triplets from the whole dataset - factor_swap_ratio_near: the same swap ratio, over adjacent triplets along a single factor traversal - factor_swap_ratio: the same swap ratio, over random triplets along a single factor traversal - axis_ratio: largest per-axis std/variance over the sum of per-axis std/variance - ave_axis_ratio: as above, but computed from values averaged over repeats - linear_ratio: largest singular value over the sum of singular values - ave_linear_ratio: as above, but computed from values averaged over repeats - axis_alignment: axis ratio is bounded by linear ratio - compute: axis / linear - ave_axis_alignment: axis ratio is bounded by linear ratio - compute: axis / linear - - Args: - dataset: DisentDataset to be sampled from. - representation_function: Function that takes observations as input and outputs a dim_representation sized representation for each observation. - factor_repeats: how many times to repeat a traversal along each factor, these are then averaged together. - batch_size: Batch size to process at any time while generating representations, should not affect metric results.
- Returns: - Dictionary with metrics - """ - fs_measures, ran_measures = aggregate_measure_distances_along_all_factors(dataset, representation_function, repeats=factor_repeats, batch_size=batch_size) - - results = {} - for k, v in fs_measures.items(): - results[f'flatness_components.{k}'] = float(filtered_mean(v, p='geometric', factor_sizes=dataset.gt_data.factor_sizes)) - for k, v in ran_measures.items(): - results[f'flatness_components.{k}'] = float(v.mean(dim=0)) - - # convert values from torch - return results - - -def filtered_mean(values, p, factor_sizes): - # increase precision - values = values.to(torch.float64) - # check size - assert values.shape == (len(factor_sizes),) - # filter - values = filter_inactive_factors(values, factor_sizes) - # compute mean - mean = torch_mean_generalized(values, dim=0, p=p) - # return decreased precision - return to_numpy(mean.to(torch.float32)) - - -def aggregate_measure_distances_along_all_factors( - dataset: DisentDataset, - representation_function, - repeats: int, - batch_size: int, -) -> (dict, dict): - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # COMPUTE AGGREGATES FOR EACH FACTOR - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - fs_measures = default_collate([ - aggregate_measure_distances_along_factor(dataset, representation_function, f_idx=f_idx, repeats=repeats, batch_size=batch_size) - for f_idx in range(dataset.gt_data.num_factors) - ]) - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # COMPUTE RANDOM SWAP RATIO - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - values = [] - num_samples = int(np.mean(dataset.gt_data.factor_sizes) * repeats) - for idxs in iter_chunks(range(num_samples), batch_size): - # encode factors - factors = dataset.gt_data.sample_factors(size=len(idxs)) - zs = encode_all_factors(dataset, representation_function, factors, batch_size=batch_size) - # get random triplets from factors - rai, rpi, rni = np.random.randint(0, len(factors), size=(3, len(factors) * 4)) - rai, rpi, rni = reorder_by_factor_dist(factors, rai, rpi, rni) - # check differences - swap_ratio_l1, swap_ratio_l2 = compute_swap_ratios(zs[rai], zs[rpi], zs[rni]) - values.append({ - 'global_swap_ratio.l1': swap_ratio_l1, - 'global_swap_ratio.l2': swap_ratio_l2, - }) - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # RETURN - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - swap_measures = default_collate(values) - return fs_measures, swap_measures - - -# ========================================================================= # -# HELPER # -# ========================================================================= # - - -def reorder_by_factor_dist(factors, rai, rpi, rni): - a_fs, p_fs, n_fs = factors[rai], factors[rpi], factors[rni] - # sort all - d_ap = np.linalg.norm(a_fs - p_fs, ord=1, axis=-1) - d_an = np.linalg.norm(a_fs - n_fs, ord=1, axis=-1) - # swap - swap_mask = d_ap <= d_an - rpi_NEW = np.where(swap_mask, rpi, rni) - rni_NEW = np.where(swap_mask, rni, rpi) - # return new - return rai, rpi_NEW, rni_NEW - - -def compute_swap_ratios(a_zs, p_zs, n_zs): - ap_delta_l1, an_delta_l1 = torch.norm(a_zs - p_zs, dim=-1, p=1), torch.norm(a_zs - n_zs, dim=-1, p=1) - ap_delta_l2, an_delta_l2 = torch.norm(a_zs - p_zs, dim=-1, p=2), torch.norm(a_zs - n_zs, dim=-1, p=2) - swap_ratio_l1 = (ap_delta_l1 <= an_delta_l1).to(torch.float32).mean() - swap_ratio_l2 = (ap_delta_l2 <= an_delta_l2).to(torch.float32).mean() - return swap_ratio_l1, swap_ratio_l2 - - -# ========================================================================= # -# CORE # -# -- using variance instead of 
standard deviation makes it easier to # -# obtain high scores. # -# ========================================================================= # - - -def compute_unsorted_axis_values(zs_traversal, use_std: bool = True): - # CORRELATIONS -- SORTED IN DESCENDING ORDER: - # correlation with standard basis (1, 0, 0, ...), (0, 1, 0, ...), ... - axis_values = torch.var(zs_traversal, dim=0) # (z_size,) - if use_std: - axis_values = torch.sqrt(axis_values) - return axis_values - - -def compute_unsorted_linear_values(zs_traversal, use_std: bool = True): - # CORRELATIONS -- SORTED IN DESCENDING ORDER: - # correlation along arbitrary orthogonal basis - _, linear_values = torch_pca(zs_traversal, center=True, mode='svd') # svd: (min(z_size, factor_size),) | eig: (z_size,) - if use_std: - linear_values = torch.sqrt(linear_values) - return linear_values - - -def _score_from_sorted(sorted_vars: torch.Tensor, use_max: bool = False, norm: bool = True) -> torch.Tensor: - if use_max: - # use two max values - n = 2 - r = sorted_vars[0] / (sorted_vars[0] + torch.max(sorted_vars[1:])) - else: - # sum all values - n = len(sorted_vars) - r = sorted_vars[0] / torch.sum(sorted_vars) - # get norm if needed - if norm: - # for: x/(x+a) - # normalised = (x/(x+a) - (1/n)) / (1 - (1/n)) - # normalised = (x - 1/(n-1) * a) / (x + a) - r = (r - (1/n)) / (1 - (1/n)) - # done! - return r - - -def score_from_unsorted(unsorted_values: torch.Tensor, use_max: bool = False, norm: bool = True): - # sort in descending order - sorted_values = torch.sort(unsorted_values, descending=True).values - # compute score - return _score_from_sorted(sorted_values, use_max=use_max, norm=norm) - - -def compute_axis_score(zs_traversal: torch.Tensor, use_std: bool = True, use_max: bool = False, norm: bool = True): - return score_from_unsorted(compute_unsorted_axis_values(zs_traversal, use_std=use_std), use_max=use_max, norm=norm) - - -def compute_linear_score(zs_traversal: torch.Tensor, use_std: bool = True, use_max: bool = False, norm: bool = True): - return score_from_unsorted(compute_unsorted_linear_values(zs_traversal, use_std=use_std), use_max=use_max, norm=norm) - - -# ========================================================================= # -# TRAVERSAL FLATNESS # -# ========================================================================= # - - -def aggregate_measure_distances_along_factor( - ground_truth_dataset: DisentDataset, - representation_function, - f_idx: int, - repeats: int, - batch_size: int, -) -> dict: - # NOTE: this returns nan for all values if the factor size is 1 - - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # FEED FORWARD, COMPUTE ALL - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - measures = [] - for i in range(repeats): - # ENCODE TRAVERSAL: - # generate repeated factors, varying one factor over the entire range - zs_traversal = encode_all_along_factor(ground_truth_dataset, representation_function, f_idx=f_idx, batch_size=batch_size) - - # SWAP RATIO: - idxs_a, idxs_p_OLD, idxs_n_OLD = torch.randint(0, len(zs_traversal), size=(3, len(zs_traversal)*2)) - idx_mask = torch.abs(idxs_a - idxs_p_OLD) <= torch.abs(idxs_a - idxs_n_OLD) - idxs_p = torch.where(idx_mask, idxs_p_OLD, idxs_n_OLD) - idxs_n = torch.where(idx_mask, idxs_n_OLD, idxs_p_OLD) - # check the number of swapped elements along a factor - near_swap_ratio_l1, near_swap_ratio_l2 = compute_swap_ratios(zs_traversal[:-2], zs_traversal[1:-1], zs_traversal[2:]) - factor_swap_ratio_l1, factor_swap_ratio_l2 = compute_swap_ratios(zs_traversal[idxs_a, :], 
zs_traversal[idxs_p, :], zs_traversal[idxs_n, :]) - - # AXIS ALIGNMENT & LINEAR SCORES - # correlation with standard basis (1, 0, 0, ...), (0, 1, 0, ...), ... - axis_values_std = compute_unsorted_axis_values(zs_traversal, use_std=True) - axis_values_var = compute_unsorted_axis_values(zs_traversal, use_std=False) - # correlation along arbitrary orthogonal basis - linear_values_std = compute_unsorted_linear_values(zs_traversal, use_std=True) - linear_values_var = compute_unsorted_linear_values(zs_traversal, use_std=False) - - # compute scores - axis_ratio_std = score_from_unsorted(axis_values_std, use_max=False, norm=True) - axis_ratio_var = score_from_unsorted(axis_values_var, use_max=False, norm=True) - linear_ratio_std = score_from_unsorted(linear_values_std, use_max=False, norm=True) - linear_ratio_var = score_from_unsorted(linear_values_var, use_max=False, norm=True) - - # save variables - measures.append({ - 'factor_swap_ratio_near.l1': near_swap_ratio_l1, - 'factor_swap_ratio_near.l2': near_swap_ratio_l2, - 'factor_swap_ratio.l1': factor_swap_ratio_l1, - 'factor_swap_ratio.l2': factor_swap_ratio_l2, - # axis ratios - '_axis_values.std': axis_values_std, - '_axis_values.var': axis_values_var, - 'axis_ratio.std': axis_ratio_std, - 'axis_ratio.var': axis_ratio_var, - # linear ratios - '_linear_values.std': linear_values_std, - '_linear_values.var': linear_values_var, - 'linear_ratio.std': linear_ratio_std, - 'linear_ratio.var': linear_ratio_var, - # normalised axis alignment scores (axis_ratio is bounded by linear_ratio) - 'axis_alignment.std': axis_ratio_std / (linear_ratio_std + 1e-20), - 'axis_alignment.var': axis_ratio_var / (linear_ratio_var + 1e-20), - }) - - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - # AGGREGATE DATA - For each distance measure - # -~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~- # - measures = default_collate(measures) - - # aggregate over first dimension - results = {k: v.mean(dim=0) for k, v in measures.items()} - - # compute average scores & remove keys - results['ave_axis_ratio.std'] = score_from_unsorted(results.pop('_axis_values.std'), use_max=False, norm=True) - results['ave_axis_ratio.var'] = score_from_unsorted(results.pop('_axis_values.var'), use_max=False, norm=True) - results['ave_linear_ratio.std'] = score_from_unsorted(results.pop('_linear_values.std'), use_max=False, norm=True) - results['ave_linear_ratio.var'] = score_from_unsorted(results.pop('_linear_values.var'), use_max=False, norm=True) - # ave normalised axis alignment scores (axis_ratio is bounded by linear_ratio) - results['ave_axis_alignment.std'] = results['ave_axis_ratio.std'] / (results['ave_linear_ratio.std'] + 1e-20) - results['ave_axis_alignment.var'] = results['ave_axis_ratio.var'] / (results['ave_linear_ratio.var'] + 1e-20) - - return results - - -# ========================================================================= # -# END # -# ========================================================================= # - - -# if __name__ == '__main__': -# from disent.metrics import metric_flatness -# from sklearn import linear_model -# from disent.dataset.groundtruth import GroundTruthDatasetTriples -# from disent.dataset.groundtruth import GroundTruthDistDataset -# from disent.metrics._flatness import get_device -# import pytorch_lightning as pl -# from torch.optim import Adam -# from torch.utils.data import DataLoader -# from disent.data.groundtruth import XYObjectData, XYSquaresData -# from disent.dataset.groundtruth import GroundTruthDataset, GroundTruthDatasetPairs -# from disent.frameworks.vae 
import BetaVae -# from disent.frameworks.vae import AdaVae -# from disent.frameworks.vae import TripletVae -# from disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder -# from disent.transform import ToImgTensorF32 -# from disent.util import colors -# from disent.util import Timer -# -# def get_str(r): -# return ', '.join(f'{k}={v:6.4f}' for k, v in r.items()) -# -# def print_r(name, steps, result, clr=colors.lYLW, t: Timer = None): -# print(f'{clr}{name:<13} ({steps:>04}){f" {colors.GRY}[{t.pretty}]{clr}" if t else ""}: {get_str(result)}{colors.RST}') -# -# def calculate(name, steps, dataset, get_repr): -# global aggregate_measure_distances_along_factor -# with Timer() as t: -# r = { -# **metric_flatness_components(dataset, get_repr, factor_repeats=64, batch_size=64), -# **metric_flatness(dataset, get_repr, factor_repeats=64, batch_size=64), -# } -# results.append((name, steps, r)) -# print_r(name, steps, r, colors.lRED, t=t) -# print(colors.GRY, '='*100, colors.RST, sep='') -# return r -# -# class XYOverlapData(XYSquaresData): -# def __init__(self, square_size=8, image_size=64, grid_spacing=None, num_squares=3, rgb=True): -# if grid_spacing is None: -# grid_spacing = (square_size+1) // 2 -# super().__init__(square_size=square_size, image_size=image_size, grid_spacing=grid_spacing, num_squares=num_squares, rgb=rgb) -# -# # datasets = [XYObjectData(rgb=False, palette='white'), XYSquaresData(), XYOverlapData(), XYObjectData()] -# datasets = [XYObjectData()] -# -# # TODO: fix for dead dimensions -# # datasets = [XYObjectData(rgb=False, palette='white')] -# -# results = [] -# for data in datasets: -# -# # dataset = GroundTruthDistDataset(data, transform=ToImgTensorF32(), num_samples=2, triplet_sample_mode='manhattan') -# # dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True) -# # module = AdaVae( -# # model=AutoEncoder( -# # encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), -# # decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), -# # ), -# # cfg=AdaVae.cfg(beta=0.001, loss_reduction='mean', optimizer=torch.optim.Adam, optimizer_kwargs=dict(lr=5e-4)) -# # ) -# -# dataset = GroundTruthDistDataset(data, transform=ToImgTensorF32(), num_samples=3, triplet_sample_mode='manhattan') -# dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, pin_memory=True) -# module = TripletVae( -# model=AutoEncoder( -# encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), -# decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), -# ), -# cfg=TripletVae.cfg(beta=0.003, loss_reduction='mean', triplet_p=1, triplet_margin_max=10.0, triplet_scale=10.0, optimizer=torch.optim.Adam, optimizer_kwargs=dict(lr=5e-4)) -# ) -# -# # we cannot guarantee which device the representation is on -# get_repr = lambda x: module.encode(x.to(module.device)) -# # PHASE 1, UNTRAINED -# pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=True, gpus=1, weights_summary=None).fit(module, dataloader) -# module = module.to('cuda') -# calculate(data.__class__.__name__, 0, dataset, get_repr) -# # PHASE 2, LITTLE TRAINING -# pl.Trainer(logger=False, checkpoint_callback=False, max_steps=256, gpus=1, weights_summary=None).fit(module, dataloader) -# calculate(data.__class__.__name__, 256, dataset, get_repr) -# # PHASE 3, MORE TRAINING -# pl.Trainer(logger=False, checkpoint_callback=False, max_steps=2048, gpus=1, weights_summary=None).fit(module, dataloader) -# calculate(data.__class__.__name__, 256+2048, dataset, get_repr) -# 
results.append(None) -# -# for result in results: -# if result is None: -# print() -# continue -# (name, steps, result) = result -# print_r(name, steps, result, colors.lYLW) diff --git a/disent/registry/__init__.py b/disent/registry/__init__.py index d3fd9c9e..82814540 100644 --- a/disent/registry/__init__.py +++ b/disent/registry/__init__.py @@ -52,11 +52,7 @@ DATASETS['smallnorb'] = _LazyImport('disent.dataset.data._groundtruth__norb') DATASETS['shapes3d'] = _LazyImport('disent.dataset.data._groundtruth__shapes3d') # groundtruth -- impl synthetic -DATASETS['xyblocks'] = _LazyImport('disent.dataset.data._groundtruth__xyblocks') # pragma: delete-on-release DATASETS['xyobject'] = _LazyImport('disent.dataset.data._groundtruth__xyobject') -DATASETS['xysquares'] = _LazyImport('disent.dataset.data._groundtruth__xysquares') # pragma: delete-on-release -DATASETS['xysquares_minimal'] = _LazyImport('disent.dataset.data._groundtruth__xysquares') # pragma: delete-on-release -DATASETS['xcolumns'] = _LazyImport('disent.dataset.data._groundtruth__xcolumns') # pragma: delete-on-release # ========================================================================= # @@ -104,23 +100,6 @@ FRAMEWORKS['info_vae'] = _LazyImport('disent.frameworks.vae._unsupervised__infovae.InfoVae') FRAMEWORKS['vae'] = _LazyImport('disent.frameworks.vae._unsupervised__vae.Vae') FRAMEWORKS['ada_vae'] = _LazyImport('disent.frameworks.vae._weaklysupervised__adavae.AdaVae') -# [AE - EXPERIMENTAL] # pragma: delete-on-release -FRAMEWORKS['x__adaneg_tae'] = _LazyImport('disent.frameworks.ae.experimental._supervised__adaneg_tae.AdaNegTripletAe') # pragma: delete-on-release -FRAMEWORKS['x__dot_ae'] = _LazyImport('disent.frameworks.ae.experimental._unsupervised__dotae.DataOverlapTripletAe') # pragma: delete-on-release -FRAMEWORKS['x__ada_ae'] = _LazyImport('disent.frameworks.ae.experimental._weaklysupervised__adaae.AdaAe') # pragma: delete-on-release -# [VAE - EXPERIMENTAL] # pragma: delete-on-release -FRAMEWORKS['x__adaave_tvae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__adaave_tvae.AdaAveTripletVae') # pragma: delete-on-release -FRAMEWORKS['x__adaneg_tvae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__adaneg_tvae.AdaNegTripletVae') # pragma: delete-on-release -FRAMEWORKS['x__ada_tvae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__adatvae.AdaTripletVae') # pragma: delete-on-release -FRAMEWORKS['x__bada_vae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__badavae.BoundedAdaVae') # pragma: delete-on-release -FRAMEWORKS['x__gada_vae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__gadavae.GuidedAdaVae') # pragma: delete-on-release -FRAMEWORKS['x__tbada_vae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__tbadavae.TripletBoundedAdaVae') # pragma: delete-on-release -FRAMEWORKS['x__tgada_vae'] = _LazyImport('disent.frameworks.vae.experimental._supervised__tgadavae.TripletGuidedAdaVae') # pragma: delete-on-release -FRAMEWORKS['x__dor_vae'] = _LazyImport('disent.frameworks.vae.experimental._unsupervised__dorvae.DataOverlapRankVae') # pragma: delete-on-release -FRAMEWORKS['x__dot_vae'] = _LazyImport('disent.frameworks.vae.experimental._unsupervised__dotvae.DataOverlapTripletVae') # pragma: delete-on-release -FRAMEWORKS['x__augpos_tvae'] = _LazyImport('disent.frameworks.vae.experimental._weaklysupervised__augpostriplet.AugPosTripletVae') # pragma: delete-on-release -FRAMEWORKS['x__st_ada_vae'] = 
_LazyImport('disent.frameworks.vae.experimental._weaklysupervised__st_adavae.SwappedTargetAdaVae') # pragma: delete-on-release -FRAMEWORKS['x__st_beta_vae'] = _LazyImport('disent.frameworks.vae.experimental._weaklysupervised__st_betavae.SwappedTargetBetaVae') # pragma: delete-on-release # ========================================================================= # @@ -206,8 +185,6 @@ METRICS = _Registry('METRICS') METRICS['dci'] = _LazyImport('disent.metrics._dci.metric_dci') METRICS['factor_vae'] = _LazyImport('disent.metrics._factor_vae.metric_factor_vae') -METRICS['flatness'] = _LazyImport('disent.metrics._flatness.metric_flatness') # pragma: delete-on-release -METRICS['flatness_components'] = _LazyImport('disent.metrics._flatness_components.metric_flatness_components') # pragma: delete-on-release METRICS['mig'] = _LazyImport('disent.metrics._mig.metric_mig') METRICS['sap'] = _LazyImport('disent.metrics._sap.metric_sap') METRICS['unsupervised'] = _LazyImport('disent.metrics._unsupervised.metric_unsupervised') diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index fc2947cc..5f25a949 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -35,8 +35,6 @@ settings: framework_opt: latent_distribution: normal # only used by VAEs - overlap_loss: NULL # only used for experimental dotvae and dorvae # pragma: delete-on-release - usage_ratio: 0.5 # only used by adversarial masked datasets # pragma: delete-on-release model: z_size: 25 diff --git a/experiment/config/config_adversarial_dataset.yaml b/experiment/config/config_adversarial_dataset.yaml deleted file mode 100644 index f3f3ad23..00000000 --- a/experiment/config/config_adversarial_dataset.yaml +++ /dev/null @@ -1,60 +0,0 @@ - -# ========================================================================= # -# CONFIG # -# ========================================================================= # - - -defaults: - - run_logging: wandb_fast - - run_location: griffin - - run_launcher: local - # entries in this file override entries from default lists - - _self_ - -settings: - job: - user: 'n_michlo' - project: 'DELETE' # exp-disentangle-dataset - name: 'no-name' # TEST-${framework.dataset_name}_${framework.adversarial_mode}_${framework.sampler_name}_s${trainer.max_steps}_${framework.optimizer_name}_lr${framework.optimizer_lr} # _wd${framework.optimizer_kwargs.weight_decay} - seed: 777 - exp: - show_every_n_steps: 500 - # saving - rel_save_dir: 'out/adversarial_data/' - save_prefix: 'PREFIX' - save_data: TRUE - dataset: - batch_size: 32 - -trainer: - # same as defaults: - run_length: ... 
- max_steps: 30001 - max_epochs: 30001 - -adv_system: - ### IMPORTANT SETTINGS ### - dataset_name: 'dsprites' # [cars3d, smallnorb, dsprites, shapes3d, xysquares_8x8_mini] - adversarial_mode: 'self' # [self, invert_margin_0.005] invert, invert_unbounded - sampler_name: 'close_p_random_n' # [close_p_random_n, same_k1_close] - - ### OTHER SETTINGS ### - # optimizer options - optimizer_name: 'Adam' - optimizer_lr: 1e-1 - optimizer_kwargs: NULL - # dataset config options - # | dataset_name: 'cars3d' # cars3d, smallnorb, xysquares_8x8_mini - dataset_batch_size: 2048 # x3 - dataset_num_workers: ${dataloader.num_workers} - data_root: ${dsettings.storage.data_root} - # adversarial loss options - # | adversarial_mode: 'invert_margin_0.005' # [self, invert_margin_0.005] invert, invert_unbounded - adversarial_swapped: FALSE - adversarial_masking: FALSE # can produce weird artefacts that look like they might go against the training process, e.g. improve disentanglement on dsprites, not actually checked by training a model on this. - adversarial_top_k: NULL # NULL or range(1, batch_size) - pixel_loss_mode: 'mse' - # sampling config - # | sampler_name: 'close_p_random_n' # [close_p_random_n] (see notes above) -- close_p_random_n, close_p_random_n_bb, same_k, same_k_close, same_k1_close, same_k (might be wrong!), same_k_close, same_k1_close, close_far, close_factor_far_random, close_far_same_factor, same_factor, random_bb, random_swap_manhat, random_swap_manhat_norm - # train options - train_batch_optimizer: TRUE - train_dataset_fp16: TRUE diff --git a/experiment/config/config_adversarial_dataset_approx.yaml b/experiment/config/config_adversarial_dataset_approx.yaml deleted file mode 100644 index e984f7e6..00000000 --- a/experiment/config/config_adversarial_dataset_approx.yaml +++ /dev/null @@ -1,121 +0,0 @@ - -# ========================================================================= # -# CONFIG # -# ========================================================================= # - -defaults: - - run_logging: wandb_fast - - run_location: griffin - - run_launcher: local - # entries in this file override entries from default lists - - _self_ - -settings: - job: - user: 'n_michlo' - project: 'DELETE' - name_prefix: 'B32' - name: '${settings.job.name_prefix}-${adv_system.dataset_name}_${adv_system.adversarial_mode}_${adv_system.samples_sort_mode}_aw${adv_system.loss_adversarial_weight}_${adv_system.sampler_name}_s${trainer.max_steps}_${adv_system.optimizer_name}_lr${adv_system.optimizer_lr}_wd${adv_system.optimizer_kwargs.weight_decay}_b${settings.dataset.batch_size}_${settings.exp.save_dtype}' - seed: 424242 - exp: - show_every_n_steps: 1000 - # saving - rel_save_dir: 'out/adversarial_data_approx/' - save_prefix: 'PREFIX' - save_model: FALSE - save_data: FALSE - save_dtype: float16 - dataset: - batch_size: 32 - -trainer: - # same as defaults: - run_length: ... - # - 15000 takes 40 mins with batch size 512 (heartofgold, 12 workers) - # - 50000 takes 33 mins with batch size 256 (griffin, 16 workers) - max_steps: 15000 - max_epochs: 15000 - -adv_system: - ### IMPORTANT SETTINGS ### - # best: - # - close_p_random_n - # note: sampler_name (adversarial_mode=invert_margin_0.005) - # - random_swap_manhattan: worst [no inversion before 5k] (probability of encountering close is too low, don't use!
++easiest to implement) - # - close_p_random_n: good [inversion before 5k] (easier to implement) - # - close_p_random_n_bb: good [inversion before 5k] (hard to implement, but pretty much the same as close_p_random_n) - # - same_k: bad [no inversion before 5k] (probability of encountering close is too low, don't use! --harder to implement, better guarantees than random_swap_manhattan) - # - same_k_close: ok [almost inversion before 5k] (harder to implement) - # - same_k1_close: best [inversion well before 5k] (easier to implement) - # note: sampler_name (adversarial_mode=self) - # - close_p_random_n: seems better based on plot of fdists vs overlap (converges better, but loss is higher which makes sense) - # - same_k1_close: seems worse based on plot of fdists vs overlap (seems to maintain original shape more, might hinder disentanglement? not actually tested) - sampler_name: 'close_p_random_n' # [random_swap_manhattan, close_p_random_n, same_k1_close] - samples_sort_mode: 'swap' # [none, swap, sort_inorder, sort_reverse] - dataset_name: 'smallnorb' # [cars3d, smallnorb, dsprites, shapes3d, xysquares_8x8_mini] - adversarial_mode: 'triplet_margin_0.1' # [self, invert_margin_0.05, invert_margin_0.005] invert, invert_unbounded - - ### OTHER SETTINGS ### - # optimizer options - optimizer_name: 'adam' - optimizer_lr: 2e-3 - optimizer_kwargs: - weight_decay: 1e-5 - # dataset config options - dataset_batch_size: ${dataloader.batch_size} # x3 - dataset_num_workers: ${dataloader.num_workers} - data_root: ${dsettings.storage.data_root} - data_load_into_memory: FALSE # I don't think this is truly multi-threaded, possible lock on array access? - # adversarial loss options - adversarial_swapped: FALSE - adversarial_masking: FALSE # can produce weird artefacts that look like they might go against the training process, e.g. improve disentanglement on dsprites, not actually checked by training a model on this.
- adversarial_top_k: NULL # NULL or range(1, batch_size) - pixel_loss_mode: 'mse' - # loss extras - loss_adversarial_weight: 10.0 - loss_out_of_bounds_weight: 1.0 # not really needed -- if this is too high it struggles to "invert" - loss_same_stats_weight: 0.0 # not really needed - loss_similarity_weight: 1.0 # important - # model settings - model_type: 'ae_conv64' # ae_conv64, ae_linear, ae_conv64norm - model_mask_mode: 'none' # std, diff, none - model_weight_init: 'xavier_normal' # [xavier_normal, default] - # logging settings - logging_scale_imgs: FALSE - - -# ========================================================================= # -# OLD EXPERIMENTS # -# ========================================================================= # - - -# EXPERIMENT SWEEP: -# -m framework.sampler_name=close_p_random_n framework.adversarial_mode=self,invert_margin_0.005 framework.dataset_name=dsprites,shapes3d,cars3d,smallnorb -# -m framework.loss_adversarial_weight=100.0 framework.sampler_name=same_k1_close framework.adversarial_mode=self2,self framework.dataset_name=dsprites,shapes3d,cars3d,smallnorb - -# EXPERIMENT INDIVIDUALS: -# framework.sampler_name=close_p_random_n framework.adversarial_mode=self framework.dataset_name=dsprites -# framework.sampler_name=close_p_random_n framework.adversarial_mode=self framework.dataset_name=shapes3d -# framework.sampler_name=close_p_random_n framework.adversarial_mode=self framework.dataset_name=cars3d -# framework.sampler_name=close_p_random_n framework.adversarial_mode=self framework.dataset_name=smallnorb - -# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.005 framework.dataset_name=dsprites -# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.005 framework.dataset_name=shapes3d -# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.005 framework.dataset_name=cars3d -# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.005 framework.dataset_name=smallnorb -# -# # 3dshapes does not seem to want to invert... 
-# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.01 framework.dataset_name=shapes3d -# framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.10 framework.dataset_name=shapes3d - -# NEW EXPERIMENT: -# -m framework.sampler_name=same_k1_close,close_p_random_n framework.adversarial_mode=invert_margin_0.05 framework.dataset_name=dsprites,shapes3d,smallnorb,cars3d -# - continue -# DONE: -m framework.sampler_name=same_k1_close,close_p_random_n framework.adversarial_mode=invert_margin_0.05 framework.dataset_name=smallnorb,cars3d -# DOING: -m framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.05 framework.dataset_name=smallnorb,cars3d -# TODO: -m framework.sampler_name=close_p_random_n framework.adversarial_mode=invert_margin_0.05 framework.dataset_name=cars3d,smallnorb - -# NEW EXPERIMENT 2: -# -m framework.sampler_name=same_k1_close,close_p_random_n framework.adversarial_mode=invert_margin_0.05 framework.loss_out_of_bounds_weight=1000.0 framework.dataset_name=dsprites,shapes3d,smallnorb,cars3d - -# NEW EXPERIMENT 3: -# -m framework.sampler_name=same_k1_close framework.adversarial_mode=invert_margin_0.05 framework.loss_out_of_bounds_weight=10000.0 framework.dataset_name=shapes3d,dsprites,cars3d,smallnorb diff --git a/experiment/config/config_adversarial_kernel.yaml b/experiment/config/config_adversarial_kernel.yaml deleted file mode 100644 index ea4020ec..00000000 --- a/experiment/config/config_adversarial_kernel.yaml +++ /dev/null @@ -1,50 +0,0 @@ -defaults: - # runtime - - run_length: short - - run_logging: wandb - - run_location: stampede_tmp - - run_launcher: slurm - # plugins - - hydra/job_logging: colorlog - - hydra/hydra_logging: colorlog - - hydra/launcher: submitit_slurm - -job: - user: 'n_michlo' - project: 'exp-disentangle-kernel' - name: r${kernel.radius}-${kernel.channels}_s${trainer.max_steps}_${optimizer.name}_lr${settings.optimizer.lr}_wd${optimizer.weight_decay}_${data.name} - -optimizer: - name: adam - lr: 3e-3 - weight_decay: 0.0 - -data: - name: 'xysquares_8x8' - -kernel: - radius: 63 - channels: 1 - disentangle_factors: NULL - # training - regularize_symmetric: TRUE - regularize_norm: FALSE # these don't work - regularize_nonneg: FALSE # these don't work - -train: - pairs_ratio: 8.0 - loss: mse - -exp: - seed: 777 - rel_save_dir: data/adversarial_kernel - save_name: ${job.name}.pt - show_every_n_steps: 1000 - -# OVERRIDE run_logging: wandb -- too fast otherwise -logging: - flush_logs_every_n_steps: 500 - -# OVERRIDE run_location: -dataset: - batch_size: 128 diff --git a/experiment/config/config_test.yaml b/experiment/config/config_test.yaml index e01a9353..63a66c8a 100644 --- a/experiment/config/config_test.yaml +++ b/experiment/config/config_test.yaml @@ -35,8 +35,6 @@ settings: framework_opt: latent_distribution: normal # only used by VAEs - overlap_loss: NULL # only used for experimental dotvae and dorvae # pragma: delete-on-release - usage_ratio: 0.5 # only used by adversarial masked datasets # pragma: delete-on-release model: z_size: 25 diff --git a/experiment/config/dataset/X--adv-cars3d--WARNING.yaml b/experiment/config/dataset/X--adv-cars3d--WARNING.yaml deleted file mode 100644 index ac8d26a5..00000000 --- a/experiment/config/dataset/X--adv-cars3d--WARNING.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - _data_type_: gt - -name: adv_cars3d - -data: - _target_: disent.dataset.data.SelfContainedHdf5GroundTruthData - h5_path: 
'${oc.env:HOME}/workspace/research/disent/out/adversarial_data_approx/2021-09-06--05-42-06_INVERT-VSTRONG-cars3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06/data.h5' - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.76418207, 0.75554032, 0.75075393] - vis_std: [0.31892905, 0.32751031, 0.33319886] - -# TODO: this does not yet copy the data to /tmp/ and thus if run on a cluster off a network drive, this will hammer the network disk. Fix this! diff --git a/experiment/config/dataset/X--adv-dsprites--WARNING.yaml b/experiment/config/dataset/X--adv-dsprites--WARNING.yaml deleted file mode 100644 index 3965bf84..00000000 --- a/experiment/config/dataset/X--adv-dsprites--WARNING.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - _data_type_: gt - -name: adv_dsprites - -data: - _target_: disent.dataset.data.SelfContainedHdf5GroundTruthData - h5_path: '${oc.env:HOME}/workspace/research/disent/out/adversarial_data_approx/2021-09-06--03-17-28_INVERT-VSTRONG-dsprites_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06/data.h5' - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.20482841] - vis_std: [0.33634909] - -# TODO: this does not yet copy the data to /tmp/ and thus if run on a cluster off a network drive, this will hammer the network disk. Fix this! diff --git a/experiment/config/dataset/X--adv-shapes3d--WARNING.yaml b/experiment/config/dataset/X--adv-shapes3d--WARNING.yaml deleted file mode 100644 index 5983845a..00000000 --- a/experiment/config/dataset/X--adv-shapes3d--WARNING.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - _data_type_: gt - -name: adv_shapes3d - -data: - _target_: disent.dataset.data.SelfContainedHdf5GroundTruthData - h5_path: '${oc.env:HOME}/workspace/research/disent/out/adversarial_data_approx/2021-09-06--00-29-23_INVERT-VSTRONG-shapes3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06/data.h5' - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.47992192, 0.51311111, 0.54627272] - vis_std: [0.28653814, 0.29201543, 0.27395435] - -# TODO: this does not yet copy the data to /tmp/ and thus if run on a cluster off a network drive, this will hammer the network disk. Fix this! diff --git a/experiment/config/dataset/X--adv-smallnorb--WARNING.yaml b/experiment/config/dataset/X--adv-smallnorb--WARNING.yaml deleted file mode 100644 index fa483e82..00000000 --- a/experiment/config/dataset/X--adv-smallnorb--WARNING.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - _data_type_: gt - -name: adv_smallnorb - -data: - _target_: disent.dataset.data.SelfContainedHdf5GroundTruthData - h5_path: '${oc.env:HOME}/workspace/research/disent/out/adversarial_data_approx/2021-09-06--09-10-59_INVERT-VSTRONG-smallnorb_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06/data.h5' - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.69691603] - vis_std: [0.21310608] - -# TODO: this does not yet copy the data to /tmp/ and thus if run on a cluster off a network drive, this will hammer the network disk. Fix this!
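The four removed adversarial dataset configs above all share one pattern, so a rough Python sketch of what hydra would instantiate for X--adv-smallnorb--WARNING.yaml may help reviewers. Constructor signatures are inferred from the `_target_` keys, and DisentDataset accepting a transform argument is an assumption; treat this as a sketch rather than the exact loading path.

import os

from disent.dataset import DisentDataset
from disent.dataset.data import SelfContainedHdf5GroundTruthData
from disent.dataset.transform import ToImgTensorF32

# h5_path from the config, with the ${oc.env:HOME} interpolation resolved by hand
h5_path = os.path.join(
    os.environ['HOME'], 'workspace/research/disent/out/adversarial_data_approx',
    '2021-09-06--09-10-59_INVERT-VSTRONG-smallnorb_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06/data.h5',
)

# assumed: DisentDataset(data, transform=...) mirrors the config's data/transform split
data = SelfContainedHdf5GroundTruthData(h5_path=h5_path)
dataset = DisentDataset(data, transform=ToImgTensorF32(mean=[0.69691603], std=[0.21310608]))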
diff --git a/experiment/config/dataset/X--dsprites-imagenet-bg-100.yaml b/experiment/config/dataset/X--dsprites-imagenet-bg-100.yaml deleted file mode 100644 index 1ab49d2c..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-bg-100.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_bg_100 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 100 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${data.meta.vis_mean} - std: ${data.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.5020433619489952, 0.47206398913310593, 0.42380018909780404] - vis_std: [0.2505510666843685, 0.25007259803668697, 0.2562415603123114] diff --git a/experiment/config/dataset/X--dsprites-imagenet-bg-20.yaml b/experiment/config/dataset/X--dsprites-imagenet-bg-20.yaml deleted file mode 100644 index 00aa4955..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-bg-20.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_bg_20 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 20 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.13294969414492142, 0.12694375140936273, 0.11733572285575933] - vis_std: [0.18311250427586276, 0.1840916474752131, 0.18607373519458442] diff --git a/experiment/config/dataset/X--dsprites-imagenet-bg-40.yaml b/experiment/config/dataset/X--dsprites-imagenet-bg-40.yaml deleted file mode 100644 index ad4674ee..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-bg-40.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_bg_40 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 40 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.2248598986983768, 0.21285772298967615, 0.19359577132944206] - vis_std: [0.1841631708032332, 0.18554895825833284, 0.1893568926398198] diff --git a/experiment/config/dataset/X--dsprites-imagenet-bg-60.yaml b/experiment/config/dataset/X--dsprites-imagenet-bg-60.yaml deleted file mode 100644 index 5a0f6550..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-bg-60.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_bg_60 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 60 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.31676960943447674, 0.29877166834408025, 0.2698556821388113] - vis_std: [0.19745897110349003, 0.1986606891520453, 0.203808842880044] diff --git 
a/experiment/config/dataset/X--dsprites-imagenet-bg-80.yaml b/experiment/config/dataset/X--dsprites-imagenet-bg-80.yaml deleted file mode 100644 index f699681e..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-bg-80.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_bg_80 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 80 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.40867981393820857, 0.38468564002021527, 0.34611573047508204] - vis_std: [0.22048328737091344, 0.22102216869942384, 0.22692977053753477] diff --git a/experiment/config/dataset/X--dsprites-imagenet-fg-100.yaml b/experiment/config/dataset/X--dsprites-imagenet-fg-100.yaml deleted file mode 100644 index 82202433..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-fg-100.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_fg_100 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 100 - mode: fg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.02067051643494642, 0.018688392816012946, 0.01632900510079384] - vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281] diff --git a/experiment/config/dataset/X--dsprites-imagenet-fg-20.yaml b/experiment/config/dataset/X--dsprites-imagenet-fg-20.yaml deleted file mode 100644 index df765265..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-fg-20.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_fg_20 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 20 - mode: fg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.038064750024334834, 0.03766780505193579, 0.03719798677641122] - vis_std: [0.17498878664096565, 0.17315570657628318, 0.1709923319496426] diff --git a/experiment/config/dataset/X--dsprites-imagenet-fg-40.yaml b/experiment/config/dataset/X--dsprites-imagenet-fg-40.yaml deleted file mode 100644 index 1d79f75d..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-fg-40.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_fg_40 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 40 - mode: fg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.03369999506331255, 0.03290657349801835, 0.03196482946320608] - vis_std: [0.155514074438101, 0.1518464537731621, 0.14750944591836743] diff --git 
a/experiment/config/dataset/X--dsprites-imagenet-fg-60.yaml b/experiment/config/dataset/X--dsprites-imagenet-fg-60.yaml deleted file mode 100644 index d65e3622..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-fg-60.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_fg_60 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 60 - mode: fg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.029335176871153983, 0.028145355435322966, 0.026731731769287146] - vis_std: [0.13663242436043319, 0.13114320478634894, 0.1246542727733097] diff --git a/experiment/config/dataset/X--dsprites-imagenet-fg-80.yaml b/experiment/config/dataset/X--dsprites-imagenet-fg-80.yaml deleted file mode 100644 index bb3c025c..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet-fg-80.yaml +++ /dev/null @@ -1,22 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_fg_80 - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 80 - mode: fg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.024956427531012196, 0.02336780403840578, 0.021475119672280243] - vis_std: [0.11864125016313823, 0.11137998105649799, 0.10281424917834255] diff --git a/experiment/config/dataset/X--dsprites-imagenet.yaml b/experiment/config/dataset/X--dsprites-imagenet.yaml deleted file mode 100644 index 6329d035..00000000 --- a/experiment/config/dataset/X--dsprites-imagenet.yaml +++ /dev/null @@ -1,54 +0,0 @@ -defaults: - - _data_type_: gt - -name: dsprites_imagenet_${dataset.mode}_${dataset.visibility} - -data: - _target_: disent.dataset.data.DSpritesImagenetData - visibility: 40 - mode: bg - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: ${exit:EXITING... dsprites-imagenet has been disabled} # ${dataset.__STATS.${dataset.name}.vis_mean} - vis_std: ${exit:EXITING... 
dsprites-imagenet has been disabled} # ${dataset.__STATS.${dataset.name}.vis_std} - -__STATS: - dsprites_imagenet_fg_100: - vis_mean: [0.02067051643494642, 0.018688392816012946, 0.01632900510079384] - vis_std: [0.10271307751834059, 0.09390213983525653, 0.08377594259970281] - dsprites_imagenet_fg_80: - vis_mean: [0.024956427531012196, 0.02336780403840578, 0.021475119672280243] - vis_std: [0.11864125016313823, 0.11137998105649799, 0.10281424917834255] - dsprites_imagenet_fg_60: - vis_mean: [0.029335176871153983, 0.028145355435322966, 0.026731731769287146] - vis_std: [0.13663242436043319, 0.13114320478634894, 0.1246542727733097] - dsprites_imagenet_fg_40: - vis_mean: [0.03369999506331255, 0.03290657349801835, 0.03196482946320608] - vis_std: [0.155514074438101, 0.1518464537731621, 0.14750944591836743] - dsprites_imagenet_fg_20: - vis_mean: [0.038064750024334834, 0.03766780505193579, 0.03719798677641122] - vis_std: [0.17498878664096565, 0.17315570657628318, 0.1709923319496426] - dsprites_imagenet_bg_100: - vis_mean: [0.5020433619489952, 0.47206398913310593, 0.42380018909780404] - vis_std: [0.2505510666843685, 0.25007259803668697, 0.2562415603123114] - dsprites_imagenet_bg_80: - vis_mean: [0.40867981393820857, 0.38468564002021527, 0.34611573047508204] - vis_std: [0.22048328737091344, 0.22102216869942384, 0.22692977053753477] - dsprites_imagenet_bg_60: - vis_mean: [0.31676960943447674, 0.29877166834408025, 0.2698556821388113] - vis_std: [0.19745897110349003, 0.1986606891520453, 0.203808842880044] - dsprites_imagenet_bg_40: - vis_mean: [0.2248598986983768, 0.21285772298967615, 0.19359577132944206] - vis_std: [0.1841631708032332, 0.18554895825833284, 0.1893568926398198] - dsprites_imagenet_bg_20: - vis_mean: [0.13294969414492142, 0.12694375140936273, 0.11733572285575933] - vis_std: [0.18311250427586276, 0.1840916474752131, 0.18607373519458442] diff --git a/experiment/config/dataset/X--mask-adv-f-cars3d.yaml b/experiment/config/dataset/X--mask-adv-f-cars3d.yaml deleted file mode 100644 index 848a271e..00000000 --- a/experiment/config/dataset/X--mask-adv-f-cars3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_f_cars3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-23-27_EXP_cars3d_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--22-58-24_EXP_cars3d_1000x256_all_std_gmean/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.Cars3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${data.meta.vis_mean} - std: ${data.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] - vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] diff --git a/experiment/config/dataset/X--mask-adv-f-dsprites.yaml b/experiment/config/dataset/X--mask-adv-f-dsprites.yaml deleted file mode 100644 index 5517a992..00000000 --- a/experiment/config/dataset/X--mask-adv-f-dsprites.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_f_dsprites - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - 
_target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-45-46_EXP_dsprites_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-21-51_EXP_dsprites_1000x256_all_std_gmean/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.DSpritesData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.042494423521889584] - vis_std: [0.19516645880626055] diff --git a/experiment/config/dataset/X--mask-adv-f-shapes3d.yaml b/experiment/config/dataset/X--mask-adv-f-shapes3d.yaml deleted file mode 100644 index 6871f130..00000000 --- a/experiment/config/dataset/X--mask-adv-f-shapes3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_f_shapes3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-33-57_EXP_shapes3d_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-09-05_EXP_shapes3d_1000x256_all_std_gmean/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.Shapes3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] - vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] diff --git a/experiment/config/dataset/X--mask-adv-f-smallnorb.yaml b/experiment/config/dataset/X--mask-adv-f-smallnorb.yaml deleted file mode 100644 index 738a8abf..00000000 --- a/experiment/config/dataset/X--mask-adv-f-smallnorb.yaml +++ /dev/null @@ -1,29 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_f_smallnorb - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-28-42_EXP_smallnorb_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-03-51_EXP_smallnorb_1000x256_all_std_gmean/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.SmallNorbData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - is_test: False - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${data.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.7520918401088603] - vis_std: [0.09563879016827262] diff --git a/experiment/config/dataset/X--mask-adv-r-cars3d.yaml 
b/experiment/config/dataset/X--mask-adv-r-cars3d.yaml deleted file mode 100644 index d7c64191..00000000 --- a/experiment/config/dataset/X--mask-adv-r-cars3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_r_cars3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--14-49-26_DISTS-SCALED_cars3d_1000x384_random_256_True_std_False/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--18-41-14_DISTS-SCALED_cars3d_1000x384_random_256_True_range_False/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.Cars3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] - vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] diff --git a/experiment/config/dataset/X--mask-adv-r-dsprites.yaml b/experiment/config/dataset/X--mask-adv-r-dsprites.yaml deleted file mode 100644 index 26a16f75..00000000 --- a/experiment/config/dataset/X--mask-adv-r-dsprites.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_r_dsprites - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--16-31-56_DISTS-SCALED_dsprites_1000x384_random_256_True_std_False/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--19-58-39_DISTS-SCALED_dsprites_1000x384_random_256_True_range_False/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.DSpritesData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.042494423521889584] - vis_std: [0.19516645880626055] diff --git a/experiment/config/dataset/X--mask-adv-r-shapes3d.yaml b/experiment/config/dataset/X--mask-adv-r-shapes3d.yaml deleted file mode 100644 index bc799876..00000000 --- a/experiment/config/dataset/X--mask-adv-r-shapes3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_r_shapes3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--15-20-48_DISTS-SCALED_shapes3d_1000x384_random_256_True_std_False/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--19-04-26_DISTS-SCALED_shapes3d_1000x384_random_256_True_range_False/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.Shapes3dData - data_root: 
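The `mask-adv-*` configs above and below all share one mechanism: `disent.dataset.wrapper.MaskedDataset` hides every observation not selected by a boolean mask, which `get_closest_mask` presumably loads from the given pickle file by choosing the stored mask whose keep-ratio is closest to `usage_ratio` (with `randomize: TRUE`, as in the `mask-ran-*` variants, the kept indices are shuffled rather than adversarially chosen). A minimal sketch of the wrapping idea, not the library's actual implementation:

```python
import numpy as np
from torch.utils.data import Dataset

class MaskedDatasetSketch(Dataset):
    """Illustrative stand-in for disent.dataset.wrapper.MaskedDataset:
    only observations where mask[i] is True remain visible."""

    def __init__(self, data, mask: np.ndarray):
        assert len(mask) == len(data), 'mask needs one entry per observation'
        self._data = data
        self._indices = np.flatnonzero(mask)  # positions of the kept observations

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, idx: int):
        return self._data[self._indices[idx]]

# usage: keep a random ~25% subset, mimicking usage_ratio=0.25 with randomize=TRUE
rng = np.random.default_rng(0)
base = list(range(1000))  # stand-in for a ground-truth dataset
subset = MaskedDatasetSketch(base, rng.random(len(base)) < 0.25)
print(len(subset), subset[0])
```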
${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] - vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] diff --git a/experiment/config/dataset/X--mask-adv-r-smallnorb.yaml b/experiment/config/dataset/X--mask-adv-r-smallnorb.yaml deleted file mode 100644 index b36c799d..00000000 --- a/experiment/config/dataset/X--mask-adv-r-smallnorb.yaml +++ /dev/null @@ -1,29 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_adv_r_smallnorb - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--15-10-07_DISTS-SCALED_smallnorb_1000x384_random_256_True_std_False/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-10-19--18-53-52_DISTS-SCALED_smallnorb_1000x384_random_256_True_range_False/data.pkl.gz' - randomize: FALSE - data: - _target_: disent.dataset.data.SmallNorbData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - is_test: False - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.7520918401088603] - vis_std: [0.09563879016827262] diff --git a/experiment/config/dataset/X--mask-dthr-cars3d.yaml b/experiment/config/dataset/X--mask-dthr-cars3d.yaml deleted file mode 100644 index c643a64f..00000000 --- a/experiment/config/dataset/X--mask-dthr-cars3d.yaml +++ /dev/null @@ -1,24 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_dthr_cars3d - -data: - _target_: disent.dataset.wrapper.DitheredDataset - dither_n: 2 - keep_ratio: ${settings.framework_opt.usage_ratio} - gt_data: - _target_: disent.dataset.data.Cars3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] - vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] diff --git a/experiment/config/dataset/X--mask-dthr-dsprites.yaml b/experiment/config/dataset/X--mask-dthr-dsprites.yaml deleted file mode 100644 index 03000f9b..00000000 --- a/experiment/config/dataset/X--mask-dthr-dsprites.yaml +++ /dev/null @@ -1,24 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_dthr_dsprites - -data: - _target_: disent.dataset.wrapper.DitheredDataset - dither_n: 2 - keep_ratio: ${settings.framework_opt.usage_ratio} - gt_data: - _target_: disent.dataset.data.DSpritesData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.042494423521889584] - vis_std: [0.19516645880626055] diff --git a/experiment/config/dataset/X--mask-dthr-shapes3d.yaml 
b/experiment/config/dataset/X--mask-dthr-shapes3d.yaml deleted file mode 100644 index 9aa229da..00000000 --- a/experiment/config/dataset/X--mask-dthr-shapes3d.yaml +++ /dev/null @@ -1,24 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_dthr_shapes3d - -data: - _target_: disent.dataset.wrapper.DitheredDataset - dither_n: 2 - keep_ratio: ${settings.framework_opt.usage_ratio} - gt_data: - _target_: disent.dataset.data.Shapes3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] - vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] diff --git a/experiment/config/dataset/X--mask-dthr-smallnorb.yaml b/experiment/config/dataset/X--mask-dthr-smallnorb.yaml deleted file mode 100644 index 28455e5f..00000000 --- a/experiment/config/dataset/X--mask-dthr-smallnorb.yaml +++ /dev/null @@ -1,25 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_dthr_smallnorb - -data: - _target_: disent.dataset.wrapper.DitheredDataset - dither_n: 2 - keep_ratio: ${settings.framework_opt.usage_ratio} - gt_data: - _target_: disent.dataset.data.SmallNorbData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - is_test: False - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.7520918401088603] - vis_std: [0.09563879016827262] diff --git a/experiment/config/dataset/X--mask-ran-cars3d.yaml b/experiment/config/dataset/X--mask-ran-cars3d.yaml deleted file mode 100644 index 59afd87e..00000000 --- a/experiment/config/dataset/X--mask-ran-cars3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_ran_cars3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-23-27_EXP_cars3d_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--22-58-24_EXP_cars3d_1000x256_all_std_gmean/data.pkl.gz' - randomize: TRUE - data: - _target_: disent.dataset.data.Cars3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868] - vis_std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404] diff --git a/experiment/config/dataset/X--mask-ran-dsprites.yaml b/experiment/config/dataset/X--mask-ran-dsprites.yaml deleted file mode 100644 index a9a1836a..00000000 --- a/experiment/config/dataset/X--mask-ran-dsprites.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_ran_dsprites - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: 
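The `mask-dthr-*` configs take a different route: `disent.dataset.wrapper.DitheredDataset` keeps a regular sublattice of the ground-truth factor grid rather than an adversarially learnt subset. A rough sketch of one plausible dithering scheme (assumed semantics, not the wrapper's actual code; `dither_n: 2` would then keep roughly half the grid, with `keep_ratio` controlling how many dither phases survive):

```python
import numpy as np

def dither_keep_mask(factor_sizes, dither_n: int = 2) -> np.ndarray:
    # Keep the points of a factor grid that fall on a regular 'dither'
    # sublattice, so the retained observations stay evenly spread over
    # every factor rather than clustering in one region.
    grids = np.meshgrid(*[np.arange(s) for s in factor_sizes], indexing='ij')
    coord_sum = sum(grids)                 # sum of factor coordinates per point
    return (coord_sum % dither_n) == 0

mask = dither_keep_mask((4, 4), dither_n=2)
print(mask.astype(int))  # checkerboard over the 4x4 factor grid
print(mask.mean())       # 0.5 -> roughly 1/dither_n of the points survive
```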
'${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-45-46_EXP_dsprites_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-21-51_EXP_dsprites_1000x256_all_std_gmean/data.pkl.gz' - randomize: TRUE - data: - _target_: disent.dataset.data.DSpritesData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.042494423521889584] - vis_std: [0.19516645880626055] diff --git a/experiment/config/dataset/X--mask-ran-shapes3d.yaml b/experiment/config/dataset/X--mask-ran-shapes3d.yaml deleted file mode 100644 index c55396f5..00000000 --- a/experiment/config/dataset/X--mask-ran-shapes3d.yaml +++ /dev/null @@ -1,28 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_ran_shapes3d - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-33-57_EXP_shapes3d_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-09-05_EXP_shapes3d_1000x256_all_std_gmean/data.pkl.gz' - randomize: TRUE - data: - _target_: disent.dataset.data.Shapes3dData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - in_memory: ${dsettings.dataset.try_in_memory} - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578] - vis_std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748] diff --git a/experiment/config/dataset/X--mask-ran-smallnorb.yaml b/experiment/config/dataset/X--mask-ran-smallnorb.yaml deleted file mode 100644 index f8d7267e..00000000 --- a/experiment/config/dataset/X--mask-ran-smallnorb.yaml +++ /dev/null @@ -1,29 +0,0 @@ -defaults: - - _data_type_: random - -name: mask_ran_smallnorb - -data: - _target_: disent.dataset.wrapper.MaskedDataset - mask: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: ${settings.framework_opt.usage_ratio} - # pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--21-28-42_EXP_smallnorb_1000x256_all_std_mean/data.pkl.gz' - pickle_file: '${oc.env:HOME}/workspace/research/disent/out/adversarial_mask/2021-09-27--23-03-51_EXP_smallnorb_1000x256_all_std_gmean/data.pkl.gz' - randomize: TRUE - data: - _target_: disent.dataset.data.SmallNorbData - data_root: ${dsettings.storage.data_root} - prepare: ${dsettings.dataset.prepare} - is_test: False - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - size: 64 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.7520918401088603] - vis_std: [0.09563879016827262] diff --git a/experiment/config/dataset/X--xyblocks.yaml b/experiment/config/dataset/X--xyblocks.yaml deleted file mode 100644 index 5eaf260d..00000000 --- a/experiment/config/dataset/X--xyblocks.yaml +++ /dev/null @@ -1,18 +0,0 @@ -defaults: - - _data_type_: gt - -name: 
xyblocks - -data: - _target_: disent.dataset.data.XYBlocksData - rgb: TRUE - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.10040509259259259, 0.10040509259259259, 0.10040509259259259] - vis_std: [0.21689087652106678, 0.21689087652106676, 0.21689087652106678] diff --git a/experiment/config/dataset/X--xyblocks_grey.yaml b/experiment/config/dataset/X--xyblocks_grey.yaml deleted file mode 100644 index 0faf884d..00000000 --- a/experiment/config/dataset/X--xyblocks_grey.yaml +++ /dev/null @@ -1,18 +0,0 @@ -defaults: - - _data_type_: gt - -name: xyblocks_grey - -data: - _target_: disent.dataset.data.XYBlocksData - rgb: FALSE - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: "${exit:EXITING... please compute the vis_mean and vis_std}" - vis_std: "${exit:EXITING... please compute the vis_mean and vis_std}" diff --git a/experiment/config/dataset/X--xysquares.yaml b/experiment/config/dataset/X--xysquares.yaml deleted file mode 100644 index e368ea3d..00000000 --- a/experiment/config/dataset/X--xysquares.yaml +++ /dev/null @@ -1,17 +0,0 @@ -defaults: - - _data_type_: gt - -name: xysquares_minimal - -data: - _target_: disent.dataset.data.XYSquaresMinimalData - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.015625, 0.015625, 0.015625] - vis_std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] diff --git a/experiment/config/dataset/X--xysquares_grey.yaml b/experiment/config/dataset/X--xysquares_grey.yaml deleted file mode 100644 index 20088abd..00000000 --- a/experiment/config/dataset/X--xysquares_grey.yaml +++ /dev/null @@ -1,23 +0,0 @@ -defaults: - - _data_type_: gt - -name: xysquares_grey - -data: - _target_: disent.dataset.data.XYSquaresData - square_size: 8 # AFFECTS: mean and std - image_size: 64 # usually ok to adjust - grid_size: 8 # usually ok to adjust - grid_spacing: 8 # usually ok to adjust - num_squares: 3 # AFFECTS: mean and std - rgb: FALSE # AFFECTS: mean and std - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [1, 64, 64] - vis_mean: [0.046146392822265625] - vis_std: [0.2096506119375896] diff --git a/experiment/config/dataset/X--xysquares_rgb.yaml b/experiment/config/dataset/X--xysquares_rgb.yaml deleted file mode 100644 index 35a45110..00000000 --- a/experiment/config/dataset/X--xysquares_rgb.yaml +++ /dev/null @@ -1,23 +0,0 @@ -defaults: - - _data_type_: gt - -name: xysquares_rgb - -data: - _target_: disent.dataset.data.XYSquaresData - square_size: 8 # AFFECTS: mean and std - image_size: 64 # usually ok to adjust - grid_size: 8 # usually ok to adjust - grid_spacing: 8 # usually ok to adjust - num_squares: 3 # AFFECTS: mean and std - rgb: TRUE # AFFECTS: mean and std - -transform: - _target_: disent.dataset.transform.ToImgTensorF32 - mean: ${dataset.meta.vis_mean} - std: ${dataset.meta.vis_std} - -meta: - x_shape: [3, 64, 64] - vis_mean: [0.015625, 0.015625, 0.015625] - vis_std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854] diff --git a/experiment/config/framework/X--adaae.yaml b/experiment/config/framework/X--adaae.yaml deleted file mode 100644 index d492ca75..00000000 --- 
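Every dataset config in this group carries hand-computed `vis_mean`/`vis_std` entries that feed the `ToImgTensorF32` transform, and the xysquares configs below flag which data parameters invalidate them ('AFFECTS: mean and std'). If those parameters change, the stats must be recomputed; one common convention, sketched here with random stand-in data (the repo's own stats helper may differ, e.g. by averaging per-image statistics):

```python
import numpy as np

def compute_vis_stats(imgs: np.ndarray):
    # Per-channel mean/std over a stack of uint8 images of shape (N, H, W, C),
    # matching the [0, 1] convention of the vis_mean / vis_std entries above.
    x = imgs.astype(np.float64) / 255.0
    return x.mean(axis=(0, 1, 2)).tolist(), x.std(axis=(0, 1, 2)).tolist()

# toy usage with random data standing in for a real dataset
imgs = np.random.randint(0, 256, size=(128, 64, 64, 3), dtype=np.uint8)
mean, std = compute_vis_stats(imgs)
print(mean, std)  # ~[0.5, 0.5, 0.5] and ~[0.29, 0.29, 0.29] for uniform noise
```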
a/experiment/config/framework/X--adaae.yaml +++ /dev/null @@ -1,19 +0,0 @@ -defaults: - - _input_mode_: pair - -name: adaae - -cfg: - _target_: disent.frameworks.ae.experimental.AdaAe.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # disable various components - disable_decoder: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - # adavae - ada_thresh_ratio: 0.5 - -meta: - model_z_multiplier: 1 diff --git a/experiment/config/framework/X--adaae_os.yaml b/experiment/config/framework/X--adaae_os.yaml deleted file mode 100644 index 67a16b46..00000000 --- a/experiment/config/framework/X--adaae_os.yaml +++ /dev/null @@ -1,19 +0,0 @@ -defaults: - - _input_mode_: weak_pair - -name: adaae - -cfg: - _target_: disent.frameworks.ae.experimental.AdaAe.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # disable various components - disable_decoder: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - # adavae - ada_thresh_ratio: 0.5 - -meta: - model_z_multiplier: 1 diff --git a/experiment/config/framework/X--adaavetvae.yaml b/experiment/config/framework/X--adaavetvae.yaml deleted file mode 100644 index 03ae727e..00000000 --- a/experiment/config/framework/X--adaavetvae.yaml +++ /dev/null @@ -1,45 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: adaave_tvae - -cfg: - _target_: disent.frameworks.vae.experimental.AdaAveTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_average_mode: gvae - ada_thresh_mode: symmetric_kl # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_loss: triplet_soft_ave_all - adat_triplet_ratio: 1.0 # >> USE WITH A SCHEDULE << 0.5 is half of triplet and ada-triplet, 1.0 is all ada-triplet - adat_triplet_soft_scale: 1.0 # >> USE WITH A SCHEDULE << - adat_triplet_pull_weight: 0.1 # Only works for: adat_triplet_loss == "triplet_hard_neg_ave_pull" - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: posterior - adat_share_ave_mode: all # Only works for: adat_triplet_loss == "triplet_hard_ave_all" - # adaave_tvae - adaave_augment_orig: FALSE # triplet over original OR averaged embeddings - adaave_decode_orig: FALSE # decode & regularize original OR averaged embeddings - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--adanegtae.yaml b/experiment/config/framework/X--adanegtae.yaml deleted file mode 100644 index f5de4d33..00000000 --- a/experiment/config/framework/X--adanegtae.yaml +++ /dev/null @@ -1,27 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: adanegtae - -cfg: - _target_: disent.frameworks.ae.experimental.AdaNegTripletAe.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: 
${settings.framework.loss_reduction} - # disable various components - disable_decoder: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - -meta: - model_z_multiplier: 1 diff --git a/experiment/config/framework/X--adanegtvae.yaml b/experiment/config/framework/X--adanegtvae.yaml deleted file mode 100644 index a321400c..00000000 --- a/experiment/config/framework/X--adanegtvae.yaml +++ /dev/null @@ -1,37 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: adanegtvae - -cfg: - _target_: disent.frameworks.vae.experimental.AdaNegTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_average_mode: gvae - ada_thresh_mode: symmetric_kl # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: posterior - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--adatvae.yaml b/experiment/config/framework/X--adatvae.yaml deleted file mode 100644 index 0f822f24..00000000 --- a/experiment/config/framework/X--adatvae.yaml +++ /dev/null @@ -1,42 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: adatvae - -cfg: - _target_: disent.frameworks.vae.experimental.AdaTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_average_mode: gvae - ada_thresh_mode: symmetric_kl # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_loss: triplet_soft_ave_all - adat_triplet_ratio: 1.0 # >> USE WITH A SCHEDULE << 0.5 is half of triplet and ada-triplet, 1.0 is all ada-triplet - adat_triplet_soft_scale: 1.0 # >> USE WITH A SCHEDULE << - adat_triplet_pull_weight: 0.1 # Only works for: adat_triplet_loss == "triplet_hard_neg_ave_pull" - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: posterior - 
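All of the `ada*` framework configs here revolve around the same Ada-GVAE style estimate: per latent dimension, a divergence between the posteriors of a pair (or triplet) is compared against a threshold placed `ada_thresh_ratio` of the way between the per-sample minimum and maximum, and dimensions below the threshold are treated as shared and averaged. A toy sketch of that thresholding rule (illustrative only, not the experimental frameworks' code):

```python
import torch

def ada_shared_mask(kl_deltas: torch.Tensor, thresh_ratio: float = 0.5) -> torch.Tensor:
    # Per sample, place a threshold thresh_ratio of the way between the
    # smallest and largest per-dimension divergence; dimensions below it
    # are considered shared between the pair and would be averaged.
    mins = kl_deltas.min(dim=-1, keepdim=True).values
    maxs = kl_deltas.max(dim=-1, keepdim=True).values
    thresh = mins + thresh_ratio * (maxs - mins)
    return kl_deltas < thresh  # True -> latent dim treated as shared

deltas = torch.tensor([[0.01, 0.02, 1.50, 0.03]])  # per-dim symmetric-KL values
print(ada_shared_mask(deltas))  # tensor([[ True,  True, False,  True]])
```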
adat_share_ave_mode: all # Only works for: adat_triplet_loss == "triplet_hard_ave_all" - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--augpos_tvae_os.yaml b/experiment/config/framework/X--augpos_tvae_os.yaml deleted file mode 100644 index d2f72dfd..00000000 --- a/experiment/config/framework/X--augpos_tvae_os.yaml +++ /dev/null @@ -1,46 +0,0 @@ -defaults: - - _input_mode_: weak_pair - -name: augpos_tvae_os - -cfg: - _target_: disent.frameworks.vae.experimental.AugPosTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # overlap - overlap_augment: - _target_: disent.transform.FftBoxBlur - p: 1.0 - radius: [ 16, 16 ] - random_mode: "batch" - random_same_xy: TRUE - - # TODO: try original - # overlap_augment: - # size = a_x.shape[2:4] - # self._augment = torchvision.transforms.RandomOrder([ - # kornia.augmentation.ColorJitter(brightness=0.25, contrast=0.25, saturation=0, hue=0.15), - # kornia.augmentation.RandomCrop(size=size, padding=8), - # # kornia.augmentation.RandomPerspective(distortion_scale=0.05, p=1.0), - # # kornia.augmentation.RandomRotation(degrees=4), - # ]) - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--badavae.yaml b/experiment/config/framework/X--badavae.yaml deleted file mode 100644 index 000bd7f5..00000000 --- a/experiment/config/framework/X--badavae.yaml +++ /dev/null @@ -1,27 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: badavae - -cfg: - _target_: disent.frameworks.vae.experimental.BoundedAdaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # adavae - ada_average_mode: gvae # gvae or ml-vae - ada_thresh_mode: symmetric_kl - ada_thresh_ratio: 0.5 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--dorvae.yaml b/experiment/config/framework/X--dorvae.yaml deleted file mode 100644 index 8a2cb997..00000000 --- a/experiment/config/framework/X--dorvae.yaml +++ /dev/null @@ -1,38 +0,0 @@ -defaults: - - _input_mode_: single - -name: dor_vae - -cfg: - _target_: disent.frameworks.vae.experimental.DataOverlapRankVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # compatibility - ada_thresh_mode: dist # kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 - adat_triplet_share_scale: 0.95 - # dorvae - overlap_loss: ${settings.framework_opt.overlap_loss} # any of 
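The commented-out 'TODO: try original' block in `X--augpos_tvae_os.yaml` above is pseudocode for an image-space augmentation pipeline; a runnable version, assuming a kornia release that still ships `ColorJitter` and keeping only the two uncommented transforms, might look like:

```python
import torch
import torchvision
import kornia.augmentation as K

# Kornia augmentation modules operate on whole (B, C, H, W) batches;
# RandomOrder applies the listed transforms in a random sequence.
size = (64, 64)  # matches x_shape in the configs
augment = torchvision.transforms.RandomOrder([
    K.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.0, hue=0.15, p=1.0),
    K.RandomCrop(size=size, padding=8),
])

x = torch.rand(16, 3, 64, 64)  # a batch of observations in [0, 1]
print(augment(x).shape)        # torch.Size([16, 3, 64, 64])
```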
the recon_loss values, or NULL to use the recon_loss value - overlap_num: 512 - # dorvae -- representation loss - overlap_repr: deterministic # deterministic, stochastic - overlap_rank_mode: spearman_rank # spearman_rank, mse_rank - overlap_inward_pressure_masked: FALSE - overlap_inward_pressure_scale: 0.01 - # dorvae -- augment - overlap_augment_mode: 'none' - overlap_augment: NULL - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--dorvae_aug.yaml b/experiment/config/framework/X--dorvae_aug.yaml deleted file mode 100644 index a2aacc27..00000000 --- a/experiment/config/framework/X--dorvae_aug.yaml +++ /dev/null @@ -1,43 +0,0 @@ -defaults: - - _input_mode_: single - -name: dor_vae_aug - -cfg: - _target_: disent.frameworks.vae.experimental.DataOverlapRankVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # compatibility - ada_thresh_mode: dist # kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 - adat_triplet_share_scale: 0.95 - # dorvae - overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value - overlap_num: 512 - # dorvae -- representation loss - overlap_repr: deterministic # deterministic, stochastic - overlap_rank_mode: spearman_rank # spearman_rank, mse_rank - overlap_inward_pressure_masked: FALSE - overlap_inward_pressure_scale: 0.01 - # dorvae -- augment - overlap_augment_mode: 'augment' - overlap_augment: - _target_: disent.transform.FftBoxBlur - p: 1.0 - radius: [16, 16] - random_mode: "batch" - random_same_xy: TRUE - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--dotae.yaml b/experiment/config/framework/X--dotae.yaml deleted file mode 100644 index b496247a..00000000 --- a/experiment/config/framework/X--dotae.yaml +++ /dev/null @@ -1,35 +0,0 @@ -defaults: - - _input_mode_: single - -name: dotae - -cfg: - _target_: disent.frameworks.ae.experimental.DataOverlapTripletAe.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # disable various components - disable_decoder: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # dotvae - overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value - overlap_num: 512 - overlap_mine_ratio: 0.1 - overlap_mine_triplet_mode: 'none' # none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s - # dotvae -- augment - overlap_augment_mode: 'none' - overlap_augment: NULL - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--dotvae.yaml b/experiment/config/framework/X--dotvae.yaml deleted file mode 100644 index c473f15d..00000000 --- a/experiment/config/framework/X--dotvae.yaml +++ /dev/null @@ 
-1,45 +0,0 @@ -defaults: - - _input_mode_: single - -name: do_tvae - -cfg: - _target_: disent.frameworks.vae.experimental.DataOverlapTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_average_mode: gvae - ada_thresh_mode: dist # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: posterior - # dotvae - overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value - overlap_num: 512 - overlap_mine_ratio: 0.1 - overlap_mine_triplet_mode: 'none' # none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s - # dotvae -- augment - overlap_augment_mode: 'none' - overlap_augment: NULL - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--dotvae_aug.yaml b/experiment/config/framework/X--dotvae_aug.yaml deleted file mode 100644 index df6c527d..00000000 --- a/experiment/config/framework/X--dotvae_aug.yaml +++ /dev/null @@ -1,70 +0,0 @@ -defaults: - - _input_mode_: single - -name: do_tvae_aug - -cfg: - _target_: disent.frameworks.vae.experimental.DataOverlapTripletVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - # adavae - ada_average_mode: gvae - ada_thresh_mode: dist # Only works for: adat_share_mask_mode == "posterior" --- kl, symmetric_kl, dist, sampled_dist - ada_thresh_ratio: 0.5 # >> USE WITH A SCHEDULE << - # ada_tvae - loss - adat_triplet_share_scale: 0.95 # >> USE WITH A SCHEDULE << only works for: adat_triplet_loss == "triplet_hard_neg_ave_scaled" - # ada_tvae - averaging - adat_share_mask_mode: posterior - # dotvae - overlap_loss: ${settings.framework_opt.overlap_loss} # any of the recon_loss values, or NULL to use the recon_loss value - overlap_num: 512 - overlap_mine_ratio: 0.1 - overlap_mine_triplet_mode: 'ran:hard_neg+easy_pos' # none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s - # dotvae -- augment - overlap_augment_mode: 'augment' - overlap_augment: - _target_: disent.transform.FftKernel - kernel: xy1_r47 - -# overlap_augment: -# _target_: disent.transform.FftBoxBlur -# p: 1.0 -# radius: [16, 16] -# random_mode: "batch" -# random_same_xy: TRUE -# - 
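The `overlap_mine_triplet_mode` strings in these `dot*` configs ('none, hard_neg, semi_hard_neg, hard_pos, easy_pos, ran:hard_neg+hard_pos <- etc, dynamically evaluated, can chain multiple "+"s') suggest a tiny mode grammar. The parser below is a guess at those semantics, with `ran:` picking one of the chained modes at random on every call; the real implementation lived in the deleted experimental frameworks:

```python
import random

def make_mode_picker(mode: str):
    # Guessed semantics: 'ran:' prefixes a '+'-chained list of mining modes
    # and selects one at random per call; anything else is a fixed mode
    # ('none' disables triplet mining entirely).
    if mode.startswith('ran:'):
        options = mode[len('ran:'):].split('+')
        return lambda: random.choice(options)
    return lambda: mode

pick = make_mode_picker('ran:hard_neg+easy_pos')
print({pick() for _ in range(20)})  # subset of {'hard_neg', 'easy_pos'}
```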
_target_: disent.transform.FftGaussianBlur -# p: 1.0 -# sigma: [0.1, 10.0] -# truncate: 3.0 -# random_mode: "batch" -# random_same_xy: FALSE -# - _target_: kornia.augmentation.RandomCrop -# p: 1.0 -# size: [64, 64] -# padding: 7 -# - _target_: kornia.augmentation.RandomPerspective -# p: 0.5 -# distortion_scale: 0.15 -# - _target_: kornia.augmentation.RandomRotation -# p: 0.5 -# degrees: 9 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--gadavae.yaml b/experiment/config/framework/X--gadavae.yaml deleted file mode 100644 index 0a830662..00000000 --- a/experiment/config/framework/X--gadavae.yaml +++ /dev/null @@ -1,29 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: gadavae - -cfg: - _target_: disent.frameworks.vae.experimental.GuidedAdaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # adavae - ada_average_mode: gvae # gvae or ml-vae - ada_thresh_mode: symmetric_kl - ada_thresh_ratio: 0.5 - # guided adavae - gada_anchor_ave_mode: 'average' # [average, thresh] - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--st-adavae.yaml b/experiment/config/framework/X--st-adavae.yaml deleted file mode 100644 index ffcaf36f..00000000 --- a/experiment/config/framework/X--st-adavae.yaml +++ /dev/null @@ -1,29 +0,0 @@ -defaults: - - _input_mode_: pair - -name: st-adavae - -cfg: - _target_: disent.frameworks.vae.experimental.SwappedTargetAdaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # adavae - ada_average_mode: gvae # gvae or ml-vae - ada_thresh_mode: symmetric_kl - ada_thresh_ratio: 0.5 - # swapped target - swap_chance: 0.1 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--st-betavae.yaml b/experiment/config/framework/X--st-betavae.yaml deleted file mode 100644 index d2273212..00000000 --- a/experiment/config/framework/X--st-betavae.yaml +++ /dev/null @@ -1,25 +0,0 @@ -defaults: - - _input_mode_: pair - -name: st-betavae - -cfg: - _target_: disent.frameworks.vae.experimental.SwappedTargetBetaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # swapped target - swap_chance: 0.1 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--tbadavae.yaml b/experiment/config/framework/X--tbadavae.yaml deleted file mode 100644 index d6b4d3ad..00000000 --- a/experiment/config/framework/X--tbadavae.yaml +++ /dev/null @@ -1,33 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: tbadavae - -cfg: - _target_: 
disent.frameworks.vae.experimental.TripletBoundedAdaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # adavae - ada_average_mode: gvae # gvae or ml-vae - ada_thresh_mode: symmetric_kl - ada_thresh_ratio: 0.5 - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/framework/X--tgadavae.yaml b/experiment/config/framework/X--tgadavae.yaml deleted file mode 100644 index 5d24f7e8..00000000 --- a/experiment/config/framework/X--tgadavae.yaml +++ /dev/null @@ -1,35 +0,0 @@ -defaults: - - _input_mode_: triplet - -name: tgadavae - -cfg: - _target_: disent.frameworks.vae.experimental.TripletGuidedAdaVae.cfg - # base ae - recon_loss: ${settings.framework.recon_loss} - loss_reduction: ${settings.framework.loss_reduction} - # base vae - latent_distribution: ${settings.framework_opt.latent_distribution} - # disable various components - disable_decoder: FALSE - disable_reg_loss: FALSE - disable_rec_loss: FALSE - disable_aug_loss: FALSE - disable_posterior_scale: NULL - # Beta-VAE - beta: ${settings.framework.beta} - # adavae - ada_average_mode: gvae # gvae or ml-vae - ada_thresh_mode: symmetric_kl - ada_thresh_ratio: 0.5 - # guided adavae - gada_anchor_ave_mode: 'average' # [average, thresh] - # tvae: triplet stuffs - triplet_loss: triplet - triplet_margin_min: 0.001 - triplet_margin_max: 1 - triplet_scale: 0.1 - triplet_p: 1 - -meta: - model_z_multiplier: 2 diff --git a/experiment/config/metrics/all.yaml b/experiment/config/metrics/all.yaml index fef33bb2..cc554da5 100644 --- a/experiment/config/metrics/all.yaml +++ b/experiment/config/metrics/all.yaml @@ -1,6 +1,4 @@ metric_list: - - flatness: {} # pragma: delete-on-release - - flatness_components: {} # pragma: delete-on-release - mig: {} - sap: {} - dci: diff --git a/experiment/config/metrics/fast.yaml b/experiment/config/metrics/fast.yaml index 1d776029..71853977 100644 --- a/experiment/config/metrics/fast.yaml +++ b/experiment/config/metrics/fast.yaml @@ -1,6 +1,4 @@ metric_list: - - flatness: {} # pragma: delete-on-release - - flatness_components: {} # pragma: delete-on-release - mig: {} - sap: {} - unsupervised: {} diff --git a/experiment/config/metrics/test.yaml b/experiment/config/metrics/test.yaml index d19f2326..698ed061 100644 --- a/experiment/config/metrics/test.yaml +++ b/experiment/config/metrics/test.yaml @@ -1,8 +1,4 @@ metric_list: - - flatness: # pragma: delete-on-release - every_n_steps: 110 # pragma: delete-on-release - - flatness_components: # pragma: delete-on-release - every_n_steps: 111 # pragma: delete-on-release - mig: every_n_steps: 112 - sap: diff --git a/experiment/config/run_location/griffin.yaml b/experiment/config/run_location/griffin.yaml deleted file mode 100644 index bd7634b0..00000000 --- a/experiment/config/run_location/griffin.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# @package _global_ - -dsettings: - trainer: - cuda: TRUE - storage: - logs_dir: 'logs' - data_root: '${oc.env:HOME}/workspace/research/disent/data/dataset' - dataset: - gpu_augment: FALSE - prepare: TRUE - try_in_memory: TRUE - -trainer: - 
prepare_data_per_node: TRUE - -dataloader: - num_workers: 32 # max 128, more than 16 doesn't really seem to help (tested on batch size 256*3)? - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} - -hydra: - job: - name: 'disent' - run: - dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - sweep: - dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/experiment/config/run_location/heartofgold.yaml b/experiment/config/run_location/heartofgold.yaml deleted file mode 100644 index 4e5bf8e3..00000000 --- a/experiment/config/run_location/heartofgold.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# @package _global_ - -dsettings: - trainer: - cuda: TRUE - storage: - logs_dir: 'logs' - data_root: '${oc.env:HOME}/workspace/research/disent/data/dataset' - dataset: - gpu_augment: FALSE - prepare: TRUE - try_in_memory: TRUE - -trainer: - prepare_data_per_node: TRUE - -dataloader: - num_workers: 12 - pin_memory: ${dsettings.trainer.cuda} - batch_size: ${settings.dataset.batch_size} - -hydra: - job: - name: 'disent' - run: - dir: '${dsettings.storage.logs_dir}/hydra_run/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - sweep: - dir: '${dsettings.storage.logs_dir}/hydra_sweep/${now:%Y-%m-%d_%H-%M-%S}_${hydra.job.name}' - subdir: '${hydra.job.id}' # hydra.job.id is not available for dir diff --git a/prepare_release.sh b/prepare_release.sh deleted file mode 100755 index da71f5c2..00000000 --- a/prepare_release.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash - -# prepare the project for a new release -# removing all the research components -# - yes this is terrible, but at the rate things are changing I -# don't want to rip things out into a separate repo... I will -# do that eventually, but not now. - -# ====== # -# HELPER # -# ====== # - -function remove_delete_commands() { - awk "!/pragma: delete-on-release/" "$1" > "$1.temp" && mv "$1.temp" "$1" -} - -function version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort --check=quiet --version-sort -} - -# check that we have the right version so -# that `shopt -s globstar` does not fail -if ! 
version_greater_equal "$BASH_VERSION" "4"; then - echo "bash version 4 is required, got: ${BASH_VERSION}" - exit 1 -fi - -# ============ # -# DELETE FILES # -# ============ # - -# RESEARCH: -rm requirements-research.txt -rm requirements-research-freeze.txt -rm -rf research/ - -# EXPERIMENT: -rm experiment/config/config_adversarial_dataset.yaml -rm experiment/config/config_adversarial_dataset_approx.yaml -rm experiment/config/config_adversarial_kernel.yaml -rm experiment/config/run_location/griffin.yaml -rm experiment/config/run_location/heartofgold.yaml -rm experiment/config/dataset/X--*.yaml -rm experiment/config/framework/X--*.yaml - -# DISENT: -# - metrics -rm disent/metrics/_flatness.py -rm disent/metrics/_flatness_components.py -# - frameworks -rm -rf disent/frameworks/ae/experimental -rm -rf disent/frameworks/vae/experimental -# - datasets -rm disent/dataset/data/_groundtruth__xcolumns.py -rm disent/dataset/data/_groundtruth__xysquares.py -rm disent/dataset/data/_groundtruth__xyblocks.py - -# DATA: -# - disent.framework.helper -rm -rf data/adversarial_kernel - -# ===================== # -# DELETE LINES OF FILES # -# ===================== # - -# enable recursive glob -shopt -s globstar - -# scan for all files that contain 'pragma: delete-on-release' -for file in **/*.{py,yaml}; do - if [ -n "$( grep -m 1 'pragma: delete-on-release' "$file" )" ]; then - echo "preparing: $file" - remove_delete_commands "$file" - fi -done - -# ===================== # -# CLEANUP THIS FILE # -# ===================== # - -rm prepare_release.sh -rm prepare_release_and_commit.sh diff --git a/prepare_release_and_commit.sh b/prepare_release_and_commit.sh deleted file mode 100755 index 1cd513bc..00000000 --- a/prepare_release_and_commit.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - - -# ====== # -# HELPER # -# ====== # - -function version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort --check=quiet --version-sort -} - -# check that we have the right version so -# that `shopt -s globstar` does not fail -if ! version_greater_equal "$BASH_VERSION" "4"; then - echo "bash version 4 is required, got: ${BASH_VERSION}" - exit 1 -fi - -# ====== # -# RUN # -# ====== # - -echo "(1/3) [GIT] Creating Prepare Branch" && \ - git checkout -b xdev-prepare && \ - ( git branch --unset-upstream 2>/dev/null || true ) && \ - \ - echo "(2/3) [PREPARE]" && \ - bash ./prepare_release.sh && \ - \ - echo "(3/3) [GIT] Committing Files" && \ - git add . && \ - git commit -m "run prepare_release.sh" - -# echo "(4/4) [GIT] Merging Changes" && \ -# git checkout dev && \ -# git merge xdev-prepare diff --git a/requirements-research-freeze.txt b/requirements-research-freeze.txt deleted file mode 100644 index 4e7f5b69..00000000 --- a/requirements-research-freeze.txt +++ /dev/null @@ -1,121 +0,0 @@ -# freeze from griffin on 2021-09-29 at 15:20 -# - Python 3.8.11 is used with miniconda-latest installed with pyenv -# - There are lots of unnecessary requirements in this list -# some have been generated with side experiments and other installs -# but experiments are confirmed to work locally with this list -# SLURM still needs to be tested and might be broken with this.
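For reference, the line-stripping step of `prepare_release.sh` above (the awk `!/pragma: delete-on-release/` filter applied to every `**/*.{py,yaml}` match) translates almost one-to-one into Python, which avoids the bash >= 4 requirement for globstar:

```python
import pathlib

PRAGMA = 'pragma: delete-on-release'

def strip_pragma_lines(path: pathlib.Path) -> None:
    # Same effect as the awk helper: rewrite the file without any line
    # that carries the release pragma, leaving other files untouched.
    lines = path.read_text().splitlines(keepends=True)
    kept = [line for line in lines if PRAGMA not in line]
    if len(kept) != len(lines):
        print(f'preparing: {path}')
        path.write_text(''.join(kept))

for file in pathlib.Path('.').rglob('*'):
    if file.is_file() and file.suffix in ('.py', '.yaml'):
        strip_pragma_lines(file)
```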
-# - install with: $ pip install --no-deps --ignore-installed -r requirements-research-freeze.txt -absl-py==0.13.0 -aiohttp==3.7.4.post0 -antlr4-python3-runtime==4.8 -argcomplete==1.12.3 -async-timeout==3.0.1 -attrs==21.2.0 -beautifulsoup4==4.10.0 -cachetools==4.2.2 -certifi==2021.5.30 -chardet==4.0.0 -charset-normalizer==2.0.4 -click==8.0.1 -cloudpickle==1.6.0 -colorlog==5.0.1 -configparser==5.0.2 -coverage==5.5 -cycler==0.10.0 -deap==1.3.1 -decorator==4.4.2 -deltas==0.7.0 -Deprecated==1.2.12 -diskcache==5.2.1 -docker-pycreds==0.4.0 -evaluations==0.0.5 -filelock==3.0.12 -fsspec==2021.7.0 -future==0.18.2 -generations==1.3.0 -gitdb==4.0.7 -GitPython==3.1.18 -google-auth==1.35.0 -google-auth-oauthlib==0.4.5 -grpcio==1.39.0 -h5py==3.3.0 -hydra-colorlog==1.0.1 -hydra-core==1.0.7 -hydra-submitit-launcher==1.1.1 -idna==3.2 -imageio==2.9.0 -imageio-ffmpeg==0.4.4 -importlib-resources==5.2.2 -iniconfig==1.1.1 -joblib==1.0.1 -kiwisolver==1.3.1 -llvmlite==0.37.0 -Logbook==1.5.3 -Markdown==3.3.4 -matplotlib==3.4.3 -member==0.0.1 -moviepy==1.0.3 -msgpack==1.0.2 -multidict==5.1.0 -numba==0.54.0 -numpy==1.20.3 -oauthlib==3.1.1 -offspring==0.1.1 -omegaconf==2.0.6 -packaging==21.0 -pathtools==0.1.2 -Pillow==8.3.1 -plotly==5.3.1 -pluggy==0.13.1 -population==0.0.1 -proglog==0.1.9 -promise==2.3 -protobuf==3.17.3 -psutil==5.8.0 -py==1.10.0 -pyasn1==0.4.8 -pyasn1-modules==0.2.8 -pyDeprecate==0.3.1 -pyparsing==2.4.7 -pytest==6.2.4 -pytest-cov==2.12.1 -python-dateutil==2.8.2 -pytorch-lightning==1.4.2 -PyYAML==5.4.1 -ray==1.6.0 -redis==3.5.3 -requests==2.26.0 -requests-oauthlib==1.3.0 -rsa==4.7.2 -ruck==0.2.2 -scikit-learn==0.24.2 -scipy==1.7.1 -sentry-sdk==1.3.1 -shortuuid==1.0.1 -six==1.16.0 -sklearn-genetic==0.4.1 -smmap==4.0.0 -soupsieve==2.2.1 -submitit==1.3.3 -subprocess32==3.5.4 -tenacity==8.0.1 -tensorboard==2.6.0 -tensorboard-data-server==0.6.1 -tensorboard-plugin-wit==1.8.0 -threadpoolctl==2.2.0 -tldr==2.0.0 -toml==0.10.2 -torch==1.9.1 -torchmetrics==0.5.0 -torchsort==0.1.6 -torchvision==0.10.1 -tqdm==4.62.1 -triton==1.0.0 -typing-extensions==3.10.0.0 -urllib3==1.26.6 -wandb==0.12.0 -Werkzeug==2.0.1 -wrapt==1.12.1 -yamlconf==0.2.4 -yarl==1.6.3 -zipp==3.5.0 diff --git a/requirements-research.txt b/requirements-research.txt deleted file mode 100644 index 93d0e876..00000000 --- a/requirements-research.txt +++ /dev/null @@ -1,19 +0,0 @@ - -# TODO: these requirements need to be cleaned up! 
- -# MISSING DEPS - these are imported in /research, but not included here, in requirements.txt OR in requirements-experiment.txt -# ============= - -# github -# matplotlib -# psutil - - -ray>=1.6.0 -ruck==0.2.4 - -seaborn>=0.11.0 -pandas>=1.3.0 -cachier>=1.5.0 - -statsmodels>=0.13.0 # required for seaborn, to estimate outliers in regression plots diff --git a/research/__init__.py b/research/__init__.py deleted file mode 100644 index 9a05a479..00000000 --- a/research/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ diff --git a/research/clog-batch.sh b/research/clog-batch.sh deleted file mode 100644 index 79602bf8..00000000 --- a/research/clog-batch.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export PROJECT="N/A" -export USERNAME="N/A" -export PARTITION="batch" -export PARALLELISM=24 - -# source the helper file -source "$(dirname "$(realpath -s "$0")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cuda_nodes "$PARTITION" 43200 "C-disent" # 12 hours diff --git a/research/clog-stampede.sh b/research/clog-stampede.sh deleted file mode 100644 index 5d53a26e..00000000 --- a/research/clog-stampede.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export PROJECT="N/A" -export USERNAME="N/A" -export PARTITION="stampede" -export PARALLELISM=24 - -# source the helper file -source "$(dirname "$(realpath -s "$0")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cuda_nodes "$PARTITION" 43200 "C-disent" # 12 hours diff --git a/research/e00_data_traversal/plots/.gitignore b/research/e00_data_traversal/plots/.gitignore deleted file mode 100644 index e33609d2..00000000 --- 
a/research/e00_data_traversal/plots/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.png diff --git a/research/e00_data_traversal/run_01_all_shared_data_prepare.sh b/research/e00_data_traversal/run_01_all_shared_data_prepare.sh deleted file mode 100644 index bdd2026e..00000000 --- a/research/e00_data_traversal/run_01_all_shared_data_prepare.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -# This script is intended to prepare all shared data on the wits cluster -# you can probably modify it for your own purposes -# - data is loaded and processed into ~/downloads/datasets which is a -# shared drive, instead of /tmp/, which is a local drive. - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="prepare-data" -export PARTITION="stampede" -export PARALLELISM=32 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -DATASETS=( - cars3d - dsprites - # monte_rollouts - # mpi3d_real - # mpi3d_realistic - # mpi3d_toy - shapes3d - smallnorb - #X--adv-cars3d--WARNING - #X--adv-dsprites--WARNING - #X--adv-shapes3d--WARNING - #X--adv-smallnorb--WARNING - #X--dsprites-imagenet - #X--dsprites-imagenet-bg-20 - #X--dsprites-imagenet-bg-40 - #X--dsprites-imagenet-bg-60 - #X--dsprites-imagenet-bg-80 - #X--dsprites-imagenet-bg-100 - #X--dsprites-imagenet-fg-20 - #X--dsprites-imagenet-fg-40 - #X--dsprites-imagenet-fg-60 - #X--dsprites-imagenet-fg-80 - #X--dsprites-imagenet-fg-100 - #X--mask-adv-f-cars3d - #X--mask-adv-f-dsprites - #X--mask-adv-f-shapes3d - #X--mask-adv-f-smallnorb - #X--mask-adv-r-cars3d - #X--mask-adv-r-dsprites - #X--mask-adv-r-shapes3d - #X--mask-adv-r-smallnorb - #X--mask-dthr-cars3d - #X--mask-dthr-dsprites - #X--mask-dthr-shapes3d - #X--mask-dthr-smallnorb - #X--mask-ran-cars3d - #X--mask-ran-dsprites - #X--mask-ran-shapes3d - #X--mask-ran-smallnorb - "X--xyblocks" - #X--xyblocks_grey - "X--xysquares" - #X--xysquares_grey - #X--xysquares_rgb - xyobject - #xyobject_grey - #xyobject_shaded - #xyobject_shaded_grey -) - -local_sweep \ - run_action=prepare_data \ - run_location=stampede_shr \ - run_launcher=local \ - dataset="$(IFS=, ; echo "${DATASETS[*]}")" diff --git a/research/e00_data_traversal/run_02_plot_data_overlap.py b/research/e00_data_traversal/run_02_plot_data_overlap.py deleted file mode 100644 index 6712defb..00000000 --- a/research/e00_data_traversal/run_02_plot_data_overlap.py +++ /dev/null @@ -1,187 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import os -from typing import Optional - -import numpy as np -from matplotlib import pyplot as plt - -import research.util as H -from disent.dataset import DisentDataset -from disent.dataset.data import Cars3dData -from disent.dataset.data import DSpritesData -from disent.dataset.data import DSpritesImagenetData -from disent.dataset.data import GroundTruthData -from disent.dataset.data import SelfContainedHdf5GroundTruthData -from disent.dataset.data import Shapes3dData -from disent.dataset.data import SmallNorbData -from disent.dataset.data import XYBlocksData -from disent.dataset.data import XYObjectData -from disent.dataset.data import XYObjectShadedData -from disent.dataset.data import XYSquaresData -from disent.util.function import wrapped_partial -from disent.util.seeds import TempNumpySeed - - -# ========================================================================= # -# core # -# ========================================================================= # - - -def ensure_rgb(img: np.ndarray) -> np.ndarray: - if img.shape[-1] == 1: - img = np.concatenate([img, img, img], axis=-1) - assert img.shape[-1] == 3, f'last channel of array is not of size 3 for RGB, got shape: {tuple(img.shape)}' - return img - - -def plot_dataset_overlap( - gt_data: GroundTruthData, - f_idxs=None, - obs_max: Optional[int] = None, - obs_spacing: int = 1, - rel_path=None, - plot_base: bool = False, - plot_combined: bool = True, - plot_sidebar: bool = False, - save=True, - seed=777, - plt_scale=4.5, - offset=0.75, -): - with TempNumpySeed(seed): - # choose an f_idx - f_idx = np.random.choice(gt_data.normalise_factor_idxs(f_idxs)) - f_name = gt_data.factor_names[f_idx] - num_cols = gt_data.factor_sizes[f_idx] - # get a traversal - factors, indices, obs = gt_data.sample_random_obs_traversal(f_idx=f_idx) - # get subset - if obs_max is not None: - max_obs_spacing, i = obs_spacing, 1 - while max_obs_spacing*obs_max > len(obs): - max_obs_spacing = obs_spacing-i - i += 1 - i = max((len(obs) - obs_max*max_obs_spacing) // 2, 0) - obs = obs[i:i+obs_max*obs_spacing:max_obs_spacing][:obs_max] - # convert - obs = np.array([ensure_rgb(x) for x in obs], dtype='float32') / 255 - # compute the distances - grid = np.zeros([len(obs), len(obs), *obs[0].shape]) - for i, i_obs in enumerate(obs): - for j, j_obs in enumerate(obs): - grid[i, j] = np.abs(i_obs - j_obs) - # normalize - grid /= grid.max() - - # make figure - factors, frames, _, _, c = grid.shape - assert c == 3 - - if plot_base: - # plot - fig, axs = H.plt_subplots_imshow(grid, label_size=18, title_size=24, title=f'{gt_data.name}: {f_name}', subplot_padding=None, figsize=(offset + (1/2.54)*frames*plt_scale, (1/2.54)*(factors+0.45)*plt_scale)) - # save figure - if save and (rel_path is not None): - path = H.make_rel_path_add_ext(rel_path, ext='.png') - plt.savefig(path) - print(f'saved: {repr(path)}') - plt.show() - - if plot_combined: - # add obs - if True: - factors += 1 - frames += 1 - # scaled_obs = obs - 
scaled_obs = obs * 0.5 + 0.25 - # grid = 1 - grid - # grid = grid * 0.5 + 0.25 - grid = np.concatenate([scaled_obs[None, :], grid], axis=0) - add_row = np.concatenate([np.ones_like(obs[0:1]), scaled_obs], axis=0) - grid = np.concatenate([grid, add_row[:, None]], axis=1) - # plot - fig, axs = H.plt_subplots_imshow(grid, label_size=18, title_size=24, row_labels=["traversal"] + (["diff."] * len(obs)), col_labels=(["diff."] * len(obs)) + ["traversal"], title=f'{gt_data.name}: {f_name}', subplot_padding=None, figsize=(offset + (1/2.54)*frames*plt_scale, (1/2.54)*(factors+0.45)*plt_scale)) - # save figure - if save and (rel_path is not None): - path = H.make_rel_path_add_ext(rel_path + '__combined', ext='.png') - plt.savefig(path) - print(f'saved: {repr(path)}') - plt.show() - - # plot - if plot_sidebar: - fig, axs = H.plt_subplots_imshow(obs[:, None], subplot_padding=None, figsize=(offset + (1/2.54)*1*plt_scale, (1/2.54)*(factors+0.45)*plt_scale)) - if save and (rel_path is not None): - path = H.make_rel_path_add_ext(rel_path + '__v', ext='.png') - plt.savefig(path) - print(f'saved: {repr(path)}') - plt.show() - fig, axs = H.plt_subplots_imshow(obs[None, :], subplot_padding=None, figsize=(offset + (1/2.54)*frames*plt_scale, (1/2.54)*(1+0.45)*plt_scale)) - if save and (rel_path is not None): - path = H.make_rel_path_add_ext(rel_path + '__h', ext='.png') - plt.savefig(path) - print(f'saved: {repr(path)}') - plt.show() - - -# ========================================================================= # -# entrypoint # -# ========================================================================= # - - -if __name__ == '__main__': - - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # options - all_squares = True - add_random_traversal = True - num_cols = 7 - seed = 48 - - for gt_data_cls, name in [ - (wrapped_partial(XYSquaresData, grid_spacing=1, grid_size=8, no_warnings=True), f'xy-squares-spacing1'), - (wrapped_partial(XYSquaresData, grid_spacing=2, grid_size=8, no_warnings=True), f'xy-squares-spacing2'), - (wrapped_partial(XYSquaresData, grid_spacing=4, grid_size=8, no_warnings=True), f'xy-squares-spacing4'), - (wrapped_partial(XYSquaresData, grid_spacing=8, grid_size=8, no_warnings=True), f'xy-squares-spacing8'), - ]: - plot_dataset_overlap(gt_data_cls(), rel_path=f'plots/overlap__{name}', obs_max=3, obs_spacing=4, seed=seed-40) - - for gt_data_cls, name in [ - (DSpritesData, f'dsprites'), - (Shapes3dData, f'shapes3d'), - (Cars3dData, f'cars3d'), - (SmallNorbData, f'smallnorb'), - ]: - gt_data = gt_data_cls() - for f_idx, f_name in enumerate(gt_data.factor_names): - plot_dataset_overlap(gt_data, rel_path=f'plots/overlap__{name}__f{f_idx}-{f_name}', obs_max=3, obs_spacing=4, f_idxs=f_idx, seed=seed) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e00_data_traversal/run_02_plot_traversals.py b/research/e00_data_traversal/run_02_plot_traversals.py deleted file mode 100644 index 15142441..00000000 --- a/research/e00_data_traversal/run_02_plot_traversals.py +++ /dev/null @@ -1,255 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software 
without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import os -from typing import Optional -from typing import Sequence -from typing import Union - -import numpy as np -from matplotlib import pyplot as plt - -import research.util as H -from disent.dataset import DisentDataset -from disent.dataset.data import Cars3dData -from disent.dataset.data import DSpritesData -from disent.dataset.data import DSpritesImagenetData -from disent.dataset.data import GroundTruthData -from disent.dataset.data import SelfContainedHdf5GroundTruthData -from disent.dataset.data import Shapes3dData -from disent.dataset.data import SmallNorbData -from disent.dataset.data import XYBlocksData -from disent.dataset.data import XYObjectData -from disent.dataset.data import XYObjectShadedData -from disent.dataset.data import XYSquaresData -from disent.util.seeds import TempNumpySeed - - -# ========================================================================= # -# core # -# ========================================================================= # - - -def ensure_rgb(img: np.ndarray) -> np.ndarray: - if img.shape[-1] == 1: - img = np.concatenate([img, img, img], axis=-1) - assert img.shape[-1] == 3, f'last channel of array is not of size 3 for RGB, got shape: {tuple(img.shape)}' - return img - - -def plot_dataset_traversals( - gt_data: GroundTruthData, - f_idxs=None, - num_cols: Optional[int] = 8, - take_cols: Optional[int] = None, - base_factors=None, - add_random_traversal: bool = True, - pad: int = 8, - bg_color: int = 127, - border: bool = False, - rel_path: str = None, - save: bool = True, - seed: int = 777, - plt_scale: float = 4.5, - offset: float = 0.75, - transpose: bool = False, - title: Union[bool, str] = True, - label_size: int = 22, - title_size: int = 26, - labels_at_top: bool = False, -): - if take_cols is not None: - assert take_cols >= num_cols - # convert - dataset = DisentDataset(gt_data) - f_idxs = gt_data.normalise_factor_idxs(f_idxs) - num_cols = num_cols if (num_cols is not None) else min(max(gt_data.factor_sizes), 32) - # get traversal grid - row_labels = [gt_data.factor_names[i] for i in f_idxs] - grid, _, _ = H.visualize_dataset_traversal( - dataset=dataset, - data_mode='raw', - factor_names=f_idxs, - num_frames=num_cols if (take_cols is None) else take_cols, - seed=seed, - base_factors=base_factors, - traverse_mode='interval', - pad=pad, - bg_color=bg_color, - border=border, - ) - if take_cols is not None: - grid = grid[:, :num_cols, ...] 
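(Editor's aside, a sketch that is not part of the original file: the grid built above is assumed to be an array of shape [n_factors, n_frames, H, W, C], so `take_cols` lets a caller sample a few extra frames and then trim back to `num_cols` (the Cars3d calls further down pass take_cols=mini_cols+1), keeping the spacing of the sampled traversal while plotting fewer columns.)

import numpy as np

grid = np.random.rand(3, 10, 64, 64, 3)  # e.g. 3 factors, 10 frames of 64x64 RGB
grid = grid[:, :7, ...]                  # keep only the first 7 frames per factor
assert grid.shape == (3, 7, 64, 64, 3)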
- # add random traversal - if add_random_traversal: - with TempNumpySeed(seed): - row_labels = ['random'] + row_labels - row = dataset.dataset_sample_batch(num_samples=num_cols, mode='raw')[None, ...] # torch.Tensor - grid = np.concatenate([ensure_rgb(row), grid]) - # make figure - factors, frames, _, _, c = grid.shape - assert c == 3 - - # get title - if isinstance(title, bool): - title = gt_data.name if title else None - - if transpose: - col_titles = None - if labels_at_top: - col_titles, row_labels = row_labels, None - fig, axs = H.plt_subplots_imshow(np.swapaxes(grid, 0, 1), label_size=label_size, title_size=title_size, title=title, titles=col_titles, titles_size=label_size, col_labels=row_labels, subplot_padding=None, figsize=(offset + (1/2.54)*frames*plt_scale, (1/2.54)*(factors+0.45)*plt_scale)[::-1]) - else: - fig, axs = H.plt_subplots_imshow(grid, label_size=label_size, title_size=title_size, title=title, row_labels=row_labels, subplot_padding=None, figsize=(offset + (1/2.54)*frames*plt_scale, (1/2.54)*(factors+0.45)*plt_scale)) - - # save figure - if save and (rel_path is not None): - path = H.make_rel_path_add_ext(rel_path, ext='.png') - plt.savefig(path) - print(f'saved: {repr(path)}') - plt.show() - # done! - return fig, axs - - -def plot_incr_overlap( - rel_path: Optional[str] = None, - spacings: Union[Sequence[int], bool] = False, - seed: int = 777, - fidx: int = 1, - traversal_size: int = 8, - traversal_lim: Optional[int] = None, - save: bool = True, - show: bool = True -): - if isinstance(spacings, bool): - spacings = ([1, 2, 3, 4, 5, 6, 7, 8] if spacings else [1, 4, 8]) - - if traversal_lim is None: - traversal_lim = traversal_size - assert traversal_size >= traversal_lim - - grid = [] - for s in spacings: - data = XYSquaresData(grid_spacing=s, grid_size=8, no_warnings=True) - with TempNumpySeed(seed): - factors, indices, obs = data.sample_random_obs_traversal(f_idx=data.normalise_factor_idx(fidx), num=traversal_size, mode='interval') - grid.append(obs[:traversal_lim]) - - w, h = traversal_lim * 2.54, len(spacings) * 2.54 - fig, axs = H.plt_subplots_imshow(grid, row_labels=[f'Space: {s}px' for s in spacings], figsize=(w, h), label_size=24) - fig.tight_layout() - - H.plt_rel_path_savefig(rel_path=rel_path, save=save, show=show) - - -# ========================================================================= # -# entrypoint # -# ========================================================================= # - - -if __name__ == '__main__': - - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # options - all_squares = False - num_cols = 7 - mini_cols = 5 - transpose_cols = 3 - seed = 47 - - INCLUDE_RANDOM_TRAVERSAL = False - TITLE = False - TITLE_MINI = False - TITLE_TRANSPOSE = False - - # get name - prefix = 'traversal' if INCLUDE_RANDOM_TRAVERSAL else 'traversal-noran' - - # plot increasing levels of overlap - plot_incr_overlap(rel_path=f'plots/traversal-incr-overlap__xy-squares', save=True, show=True, traversal_lim=None) - - # mini versions - plot_dataset_traversals(XYSquaresData(), rel_path=f'plots/traversal-mini__xy-squares__spacing8', title=TITLE_MINI, seed=seed, transpose=False, add_random_traversal=False, num_cols=mini_cols) - plot_dataset_traversals(Shapes3dData(), rel_path=f'plots/traversal-mini__shapes3d', title=TITLE_MINI, seed=seed, transpose=False, add_random_traversal=False, num_cols=mini_cols) - plot_dataset_traversals(DSpritesData(), rel_path=f'plots/traversal-mini__dsprites', title=TITLE_MINI, seed=seed, 
transpose=False, add_random_traversal=False, num_cols=mini_cols) - plot_dataset_traversals(SmallNorbData(), rel_path=f'plots/traversal-mini__smallnorb', title=TITLE_MINI, seed=seed, transpose=False, add_random_traversal=False, num_cols=mini_cols) - plot_dataset_traversals(Cars3dData(), rel_path=f'plots/traversal-mini__cars3d', title=TITLE_MINI, seed=seed, transpose=False, add_random_traversal=False, num_cols=mini_cols, take_cols=mini_cols+1) - - # transpose versions - plot_dataset_traversals(XYSquaresData(), rel_path=f'plots/traversal-transpose__xy-squares__spacing8', title=TITLE_TRANSPOSE, offset=0.95, label_size=23, seed=seed, labels_at_top=True, transpose=True, add_random_traversal=False, num_cols=transpose_cols) - plot_dataset_traversals(Shapes3dData(), rel_path=f'plots/traversal-transpose__shapes3d', title=TITLE_TRANSPOSE, offset=0.95, label_size=23, seed=seed, labels_at_top=True, transpose=True, add_random_traversal=False, num_cols=transpose_cols) - plot_dataset_traversals(DSpritesData(), rel_path=f'plots/traversal-transpose__dsprites', title=TITLE_TRANSPOSE, offset=0.95, label_size=23, seed=seed, labels_at_top=True, transpose=True, add_random_traversal=False, num_cols=transpose_cols) - plot_dataset_traversals(SmallNorbData(), rel_path=f'plots/traversal-transpose__smallnorb', title=TITLE_TRANSPOSE, offset=0.95, label_size=23, seed=seed, labels_at_top=True, transpose=True, add_random_traversal=False, num_cols=transpose_cols) - plot_dataset_traversals(Cars3dData(), rel_path=f'plots/traversal-transpose__cars3d', title=TITLE_TRANSPOSE, offset=0.95, label_size=23, seed=seed, labels_at_top=True, transpose=True, add_random_traversal=False, num_cols=transpose_cols, take_cols=mini_cols+1) - - # save images - for i in ([1, 2, 3, 4, 5, 6, 7, 8] if all_squares else [1, 2, 4, 8]): - data = XYSquaresData(grid_spacing=i, grid_size=8, no_warnings=True) - plot_dataset_traversals(data, rel_path=f'plots/{prefix}__xy-squares__spacing{i}', title=TITLE, seed=seed-40, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(data, rel_path=f'plots/{prefix}__xy-squares__spacing{i}__some', title=TITLE, seed=seed-40, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols, f_idxs=[0, 3]) - - plot_dataset_traversals(Shapes3dData(), rel_path=f'plots/{prefix}__shapes3d', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(DSpritesData(), rel_path=f'plots/{prefix}__dsprites', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(SmallNorbData(), rel_path=f'plots/{prefix}__smallnorb', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(Cars3dData(), rel_path=f'plots/{prefix}__cars3d', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - - exit(1) - - plot_dataset_traversals(XYObjectData(), rel_path=f'plots/{prefix}__xy-object', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(XYObjectShadedData(), rel_path=f'plots/{prefix}__xy-object-shaded', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(XYBlocksData(), rel_path=f'plots/{prefix}__xy-blocks', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - - plot_dataset_traversals(DSpritesImagenetData(100, 'bg'), 
rel_path=f'plots/{prefix}__dsprites-imagenet-bg-100', title=TITLE, seed=seed-6, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(DSpritesImagenetData( 50, 'bg'), rel_path=f'plots/{prefix}__dsprites-imagenet-bg-50', title=TITLE, seed=seed-6, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(DSpritesImagenetData(100, 'fg'), rel_path=f'plots/{prefix}__dsprites-imagenet-fg-100', title=TITLE, seed=seed-6, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - plot_dataset_traversals(DSpritesImagenetData( 50, 'fg'), rel_path=f'plots/{prefix}__dsprites-imagenet-fg-50', title=TITLE, seed=seed-6, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - - BASE = os.path.abspath(os.path.join(__file__, '../../../out/adversarial_data_approx')) - - for folder in [ - # 'const' datasets - ('2021-08-18--00-58-22_FINAL-dsprites_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--01-33-47_FINAL-shapes3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--02-20-13_FINAL-cars3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--03-10-53_FINAL-smallnorb_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - # 'invert' datasets - ('2021-08-18--03-52-31_FINAL-dsprites_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--04-29-25_FINAL-shapes3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--05-13-15_FINAL-cars3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('2021-08-18--06-03-32_FINAL-smallnorb_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - # stronger 'invert' datasets - ('2021-09-06--00-29-23_INVERT-VSTRONG-shapes3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('2021-09-06--03-17-28_INVERT-VSTRONG-dsprites_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('2021-09-06--05-42-06_INVERT-VSTRONG-cars3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('2021-09-06--09-10-59_INVERT-VSTRONG-smallnorb_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ]: - plot_dataset_traversals(SelfContainedHdf5GroundTruthData(f'{BASE}/{folder}/data.h5'), rel_path=f'plots/{prefix}__{folder}.png', title=TITLE, seed=seed, add_random_traversal=INCLUDE_RANDOM_TRAVERSAL, num_cols=num_cols) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e00_tuning/submit_param_tuning.sh b/research/e00_tuning/submit_param_tuning.sh deleted file mode 100644 index 4f554dfa..00000000 --- a/research/e00_tuning/submit_param_tuning.sh +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash - - -# OVERVIEW: -# - this experiment is designed to find the optimal hyper-parameters for disentanglement, as well as investigate the -# effect of the adversarial XYSquares dataset against existing approaches. 
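(Aside: the run-count annotations used throughout these scripts, e.g. the "1 * (2 * 8 * 2 * 5) = 160" comment before the first sweep below, are simply the size of the hydra multirun grid. A quick sketch, with the option lists copied from that sweep, to sanity-check the arithmetic:)

from itertools import product

betas      = [0.000316, 0.001, 0.00316, 0.01, 0.0316, 0.1, 0.316, 1.0]
frameworks = ['betavae', 'adavae_os']
z_sizes    = [9, 25]
datasets   = ['dsprites', 'shapes3d', 'cars3d', 'smallnorb', 'X--xysquares']

# every combination of the swept options becomes one job
runs = list(product(betas, frameworks, z_sizes, datasets))
assert len(runs) == 2 * 8 * 2 * 5  # = 160 jobs for a single repeat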
- - -# OUTCOMES: -# - Existing frameworks fail on the adversarial dataset -# - Much lower beta is required for adversarial dataset - - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="CVPR-00__basic-hparam-tuning" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 129600 "C-disent" # 36 hours - -# RUN SWEEP FOR GOOD BETA VALUES -# - beta: 0.01, 0.0316 seem good, 0.1 starts getting too strong, 0.00316 is a bit weak -# - z_size: higher means you can increase beta, eg. 25: beta=0.1 and 9: beta=0.01 -# - framework: adavae needs lower beta, eg. betavae: 0.1, adavae25: 0.0316, adavae9: 0.00316 -# - xy_squares really struggles to learn when non-overlapping, beta needs to be very low. -# might be worth using a warmup schedule -# betavae with zsize=25 and beta<=0.00316 -# betavae with zsize=09 and beta<=0.000316 -# adavae with zsize=25 does not work -# adavae with zsize=09 and beta<=0.001 (must get very lucky) - -# TODO: I should try lower the learning rate to 1e-4 from 1e-3, this might help with xysquares -# 1 * (2 * 8 * 2 * 5) = 160 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_beta' \ - hydra.job.name="vae_hparams" \ - \ - run_length=long \ - metrics=all \ - \ - settings.framework.beta=0.000316,0.001,0.00316,0.01,0.0316,0.1,0.316,1.0 \ - framework=betavae,adavae_os \ - schedule=none \ - settings.model.z_size=9,25 \ - \ - dataset=dsprites,shapes3d,cars3d,smallnorb,X--xysquares \ - sampling=default__bb - - -# TEST DISTANCES IN AEs VS VAEs -# -- supplementary material -# 3 * (1 * 5 = 2) = 15 -submit_sweep \ - +DUMMY.repeat=1,2,3 \ - +EXTRA.tags='sweep_ae' \ - hydra.job.name="ae_test" \ - \ - run_length=medium \ - metrics=all \ - \ - settings.framework.beta=0.0001 \ - framework=ae \ - schedule=none \ - settings.model.z_size=25 \ - \ - dataset=dsprites,shapes3d,cars3d,smallnorb,X--xysquares \ - sampling=default__bb - - -# RUN SWEEP FOR GOOD SCHEDULES -# -- unused -# 1 * (3 * 2 * 4 * 5) = 120 -#submit_sweep \ -# +DUMMY.repeat=1 \ -# +EXTRA.tags='sweep_schedule' \ -# \ -# run_length=long \ -# metrics=all \ -# \ -# settings.framework.beta=0.1,0.316,1.0 \ -# framework=betavae,adavae_os \ -# schedule=beta_cyclic,beta_cyclic_slow,beta_cyclic_fast,beta_decrease \ -# settings.model.z_size=25 \ -# \ -# dataset=dsprites,shapes3d,cars3d,smallnorb,X--xysquares \ -# sampling=default__bb diff --git a/research/e01_incr_overlap/run.py b/research/e01_incr_overlap/run.py deleted file mode 100644 index 9a96337a..00000000 --- a/research/e01_incr_overlap/run.py +++ /dev/null @@ -1,72 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, 
subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -import numpy as np -from disent.dataset.data import XYSquaresData - - -class XYSquaresSampler(XYSquaresData): - - def sample_1d_boxes(self, size=None): - size = (2,) if (size is None) else ((size, 2) if isinstance(size, int) else (*size, 2)) - # sample x0, y0 - s0 = self._offset + self._spacing * np.random.randint(0, self._placements, size=size) - # sample x1, y1 - s1 = s0 + self._square_size - # return (x0, y0), (x1, y1) - return s0, s1 - - def sample_1d_overlap(self, size=None): - s0, s1 = self.sample_1d_boxes(size=size) - # compute overlap - return np.maximum(np.min(s1, axis=-1) - np.max(s0, axis=-1), 0) - - def sample_1d_delta(self, size=None): - s0, s1 = self.sample_1d_boxes(size=size) - # compute differences - l_delta = np.max(s0, axis=-1) - np.min(s0, axis=-1) - r_delta = np.max(s1, axis=-1) - np.min(s1, axis=-1) - # return delta - return np.minimum(l_delta + r_delta, self._square_size * 2) - - -if __name__ == '__main__': - - print('\nDecreasing Spacing & Increasing Size') - for ss, gs in [(8, 8), (9, 7), (17, 6), (25, 5), (33, 4), (41, 3), (49, 2), (57, 1)][::-1]: - d = XYSquaresSampler(square_size=ss, grid_spacing=gs, max_placements=8, no_warnings=True) - print('ss={:2d} gs={:1d} overlap={:7.4f} delta={:7.4f}'.format(ss, gs, d.sample_1d_overlap(size=1_000_000).mean(), d.sample_1d_delta(size=1_000_000).mean())) - - print('\nDecreasing Spacing') - for i in range(8): - ss, gs = 8, 8-i - d = XYSquaresSampler(square_size=ss, grid_spacing=gs, max_placements=8, no_warnings=True) - print('ss={:2d} gs={:1d} overlap={:7.4f} delta={:7.4f}'.format(ss, gs, d.sample_1d_overlap(size=1_000_000).mean(), d.sample_1d_delta(size=1_000_000).mean())) - - print('\nDecreasing Spacing & Keeping Dimension Size Constant') - for i in range(8): - ss, gs = 8, 8-i - d = XYSquaresSampler(square_size=ss, grid_spacing=gs, max_placements=None, no_warnings=True) - print('ss={:2d} gs={:1d} overlap={:7.4f} delta={:7.4f}'.format(ss, gs, d.sample_1d_overlap(size=1_000_000).mean(), d.sample_1d_delta(size=1_000_000).mean())) diff --git a/research/e01_incr_overlap/submit_incr_overlap.sh b/research/e01_incr_overlap/submit_incr_overlap.sh deleted file mode 100644 index 49c28043..00000000 --- a/research/e01_incr_overlap/submit_incr_overlap.sh +++ /dev/null @@ -1,127 +0,0 @@ -#!/bin/bash - - -# OVERVIEW: -# - this experiment is designed to check how increasing overlap (reducing -# the spacing between square positions on XYSquares) affects learning. - - -# OUTCOMES: -# - increasing overlap improves disentanglement & ability for the -# neural network to learn values. -# - decreasing overlap worsens disentanglement, but it also becomes -# very hard for the neural net to learn specific values needed. The -# average image does not correspond well to individual samples. 
-# Disentanglement performance is also a result of this fact, as -# the network can't always learn the dataset effectively. - - -# ========================================================================= # -# Settings # -# ========================================================================= # - - -export USERNAME="n_michlo" -export PROJECT="CVPR-01__incr_overlap" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - - -# ========================================================================= # -# Experiment # -# ========================================================================= # - - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - - -# background launch various xysquares -# -- original experiment also had dfcvae -# -- beta is too high for adavae -# 5 * (2*2*8 = 32) = 160 -submit_sweep \ - +DUMMY.repeat=1,2,3,4,5 \ - +EXTRA.tags='sweep_xy_squares_overlap' \ - hydra.job.name="incr_ovlp" \ - \ - run_length=medium \ - metrics=all \ - \ - settings.framework.beta=0.001,0.00316 \ - framework=betavae,adavae_os \ - settings.model.z_size=9 \ - \ - sampling=default__bb \ - dataset=X--xysquares_rgb \ - dataset.data.grid_spacing=8,7,6,5,4,3,2,1 - - -# background launch various xysquares -# -- original experiment also had dfcvae -# -- beta is too high for adavae -# 5 * (2*8 = 16) = 80 -submit_sweep \ - +DUMMY.repeat=1,2,3,4,5 \ - +EXTRA.tags='sweep_xy_squares_overlap_small_beta' \ - hydra.job.name="sb_incr_ovlp" \ - \ - run_length=medium \ - metrics=all \ - \ - settings.framework.beta=0.0001,0.00001 \ - framework=adavae_os \ - settings.model.z_size=9 \ - \ - sampling=default__bb \ - dataset=X--xysquares_rgb \ - dataset.data.grid_spacing=8,7,6,5,4,3,2,1 - - -# background launch various xysquares -# - this time we try delay beta, so that it can learn properly... -# - NOTE: this doesn't actually work, the VAE loss often becomes -# NAN because the values are too small. 
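(Aside: for intuition on the overlap statistic printed by run.py above, here is a self-contained sketch that re-implements the 1d interval overlap independently of the deleted XYSquaresSampler class:)

import numpy as np

def mean_overlap(square_size: int, spacing: int, placements: int = 8,
                 n: int = 100_000, seed: int = 0) -> float:
    """Mean 1d overlap (in pixels) between two randomly placed squares."""
    rng = np.random.default_rng(seed)
    lefts = spacing * rng.integers(0, placements, size=(n, 2))  # left edges
    rights = lefts + square_size                                # right edges
    # two intervals [l, r) overlap by max(min(rights) - max(lefts), 0)
    return np.maximum(rights.min(axis=-1) - lefts.max(axis=-1), 0).mean()

# 8px squares: dense placements overlap heavily, while the default spacing of
# 8 only overlaps when both squares land on the same cell (mean = 8 * 1/8 = 1)
print(mean_overlap(square_size=8, spacing=1))  # ~5.4
print(mean_overlap(square_size=8, spacing=8))  # ~1.0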
-# 3 * (2*2*8 = 32) = 96 -# submit_sweep \ -# +DUMMY.repeat=1,2,3 \ -# +EXTRA.tags='sweep_xy_squares_overlap_delay' \ -# hydra.job.name="schd_incr_ovlp" \ -# \ -# schedule=beta_delay_long \ -# \ -# run_length=medium \ -# metrics=all \ -# \ -# settings.framework.beta=0.001 \ -# framework=betavae,adavae_os \ -# settings.model.z_size=9,25 \ -# \ -# sampling=default__bb \ -# dataset=X--xysquares_rgb \ -# dataset.data.grid_spacing=8,7,6,5,4,3,2,1 - - -# background launch traditional datasets -# -- original experiment also had dfcvae -# 5 * (2*2*4 = 16) = 80 -#submit_sweep \ -# +DUMMY.repeat=1,2,3,4,5 \ -# +EXTRA.tags='sweep_other' \ -# \ -# run_length=medium \ -# metrics=all \ -# \ -# settings.framework.beta=0.01,0.0316 \ -# framework=betavae,adavae_os \ -# settings.model.z_size=9 \ -# \ -# sampling=default__bb \ -# dataset=cars3d,shapes3d,dsprites,smallnorb - - -# ========================================================================= # -# DONE # -# ========================================================================= # diff --git a/research/e01_visual_overlap/plots/.gitignore b/research/e01_visual_overlap/plots/.gitignore deleted file mode 100644 index e33609d2..00000000 --- a/research/e01_visual_overlap/plots/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.png diff --git a/research/e01_visual_overlap/run_01_x_z_recon_dists.sh b/research/e01_visual_overlap/run_01_x_z_recon_dists.sh deleted file mode 100644 index 5abdc019..00000000 --- a/research/e01_visual_overlap/run_01_x_z_recon_dists.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-01__gt-vs-learnt-dists" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - - -# 1 * (3 * 6 * 4 * 2) = 144 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep' \ - \ - model=linear,vae_fc,vae_conv64 \ - \ - run_length=medium \ - metrics=all \ - \ - dataset=xyobject,xyobject_shaded,shapes3d,dsprites,cars3d,smallnorb \ - sampling=default__bb \ - framework=ae,X--adaae_os,betavae,adavae_os \ - \ - settings.framework.beta=0.0316 \ - settings.optimizer.lr=3e-4 \ - settings.model.z_size=9,25 diff --git a/research/e01_visual_overlap/run_plot_global_dists.py b/research/e01_visual_overlap/run_plot_global_dists.py deleted file mode 100644 index 08f1c239..00000000 --- a/research/e01_visual_overlap/run_plot_global_dists.py +++ /dev/null @@ -1,465 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# 
all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -import os -from collections import defaultdict -from typing import Dict - -import seaborn as sns -import numpy as np -import pandas as pd -import torch -from matplotlib import pyplot as plt -from matplotlib.ticker import MultipleLocator -from tqdm import tqdm - -import research.util as H -from disent.dataset import DisentDataset -from disent.dataset.data import Cars3dData -from disent.dataset.data import DSpritesData -from disent.dataset.data import Shapes3dData -from disent.dataset.data import XYSquaresData -from disent.dataset.transform import ToImgTensorF32 -from disent.util import to_numpy -from disent.util.function import wrapped_partial - - -# ========================================================================= # -# plot # -# ========================================================================= # - - -def plot_overlap(a, b, mode='abs'): - a, b = np.transpose(to_numpy(a), (1, 2, 0)), np.transpose(to_numpy(b), (1, 2, 0)) - if mode == 'binary': - d = np.float32(a != b) - elif mode == 'abs': - d = np.abs(a - b) - elif mode == 'diff': - d = a - b - else: - raise KeyError - d = (d - d.min()) / (d.max() - d.min()) - a, b, d = np.uint8(a * 255), np.uint8(b * 255), np.uint8(d * 255) - fig, (ax_a, ax_b, ax_d) = plt.subplots(1, 3) - ax_a.imshow(a) - ax_b.imshow(b) - ax_d.imshow(d) - plt.show() - - -# ========================================================================= # -# CORE # -# ========================================================================= # - - -def generate_data(dataset: DisentDataset, data_name: str, batch_size=64, samples=100_000, plot_diffs=False, load_cache=True, save_cache=True, overlap_loss: str = 'mse'): - # cache - file_path = os.path.join(os.path.dirname(__file__), f'cache/{data_name}_{samples}.pkl') - if load_cache: - if os.path.exists(file_path): - print(f'loaded: {file_path}') - return pd.read_pickle(file_path, compression='gzip') - - # generate - with torch.no_grad(): - # dataframe - df = defaultdict(lambda: defaultdict(list)) - - # randomly overlapped data - name = 'random' - for i in tqdm(range((samples + (batch_size-1) - 1) // (batch_size-1)), desc=f'{data_name}: {name}'): - # get random batch of unique elements - idxs = H.sample_unique_batch_indices(num_obs=len(dataset), num_samples=batch_size) - batch = dataset.dataset_batch_from_indices(idxs, mode='input') - # plot - if plot_diffs and (i == 0): - plot_overlap(batch[0], batch[1]) - # store overlap results - o = to_numpy(H.pairwise_overlap(batch[:-1], batch[1:], mode=overlap_loss)) - df[True][name].extend(o) - df[False][name].extend(o) - - # traversal overlaps - for f_idx in range(dataset.gt_data.num_factors): - name = f'f_{dataset.gt_data.factor_names[f_idx]}' - for i in tqdm(range((samples + (dataset.gt_data.factor_sizes[f_idx] - 1) - 1) // (dataset.gt_data.factor_sizes[f_idx] - 1)), desc=f'{data_name}: {name}'): - # get random batch that is a factor traversal - factors = 
dataset.gt_data.sample_random_factor_traversal(f_idx) - batch = dataset.dataset_batch_from_factors(factors, mode='input') - # shuffle indices - idxs = np.arange(len(factors)) - np.random.shuffle(idxs) - # plot - if plot_diffs and (i == 0): plot_overlap(batch[0], batch[1]) - # store overlap results - df[True][name].extend(to_numpy(H.pairwise_overlap(batch[:-1], batch[1:], mode=overlap_loss))) - df[False][name].extend(to_numpy(H.pairwise_overlap(batch[idxs[:-1]], batch[idxs[1:]], mode=overlap_loss))) - - # make dataframe! - df = pd.DataFrame({ - 'overlap': [d for ordered, data in df.items() for name, dat in data.items() for d in dat], - 'samples': [name for ordered, data in df.items() for name, dat in data.items() for d in dat], - 'ordered': [ordered for ordered, data in df.items() for name, dat in data.items() for d in dat], - 'data': [data_name for ordered, data in df.items() for name, dat in data.items() for d in dat], - }) - - # save into cache - if save_cache: - os.makedirs(os.path.dirname(file_path), exist_ok=True) - df.to_pickle(file_path, compression='gzip') - print(f'cached: {file_path}') - - return df - - -# ========================================================================= # -# plotting # -# ========================================================================= # - - -def dual_plot_from_generated_data(df: pd.DataFrame, data_name: str = None, save_name: str = None, tick_size: float = None, fig_l_pad=1, fig_w=7, fig_h=13): - # make subplots - cm = 1 / 2.54 - fig, (ax0, ax1) = plt.subplots(1, 2, figsize=((fig_l_pad+2*fig_w)*cm, fig_h*cm)) - if data_name is not None: - fig.suptitle(data_name, fontsize=20) - ax0.set_ylim(-0.025, 1.025) - ax1.set_ylim(-0.025, 1.025) - # plot - ax0.set_title('Ordered Traversals') - sns.ecdfplot(ax=ax0, data=df[df['ordered']==True], x="distance", hue="samples") - ax1.set_title('Shuffled Traversals') - sns.ecdfplot(ax=ax1, data=df[df['ordered']==False], x="distance", hue="samples") - # edit plots - ax0.set_xlabel('Visual Distance') - ax1.set_xlabel('Visual Distance') - if tick_size is not None: - ax0.xaxis.set_major_locator(MultipleLocator(base=tick_size)) - ax1.xaxis.set_major_locator(MultipleLocator(base=tick_size)) - # ax0.xaxis.set_major_formatter(FormatStrFormatter('%.2f')) - # ax1.xaxis.set_major_formatter(FormatStrFormatter('%.2f')) - ax0.set_ylabel('Cumulative Proportion') - ax1.set_ylabel(None) - ax1.set_yticklabels([]) - ax1.get_legend().remove() - plt.tight_layout() - # save - if save_name is not None: - path = os.path.join(os.path.dirname(__file__), 'plots', save_name) - os.makedirs(os.path.dirname(path), exist_ok=True) - plt.savefig(path) - print(f'saved: {path}') - # show - return fig - - -def all_plot_from_all_generated_data(dfs: dict, ordered=True, save_name: str = None, tick_sizes: Dict[str, float] = None, hide_extra_legends=False, fig_l_pad=1, fig_w=7, fig_h=13): - if not dfs: - return None - # make subplots - cm = 1 / 2.54 - fig, axs = plt.subplots(1, len(dfs), figsize=((fig_l_pad+len(dfs)*fig_w)*cm, fig_h * cm)) - axs = np.array(axs, dtype=np.object).reshape((-1,)) - # plot all - for i, (ax, (data_name, df)) in enumerate(zip(axs, dfs.items())): - # plot - ax.set_title(data_name) - sns.ecdfplot(ax=ax, data=df[df['ordered']==ordered], x="distance", hue="samples") - # edit plots - ax.set_ylim(-0.025, 1.025) - ax.set_xlabel('Visual Distance') - if (tick_sizes is not None) and (data_name in tick_sizes): - ax.xaxis.set_major_locator(MultipleLocator(base=tick_sizes[data_name])) - if i == 0: - ax.set_ylabel('Cumulative Proportion') - 
else: - if hide_extra_legends: - ax.get_legend().remove() - ax.set_ylabel(None) - ax.set_yticklabels([]) - plt.tight_layout() - # save - if save_name is not None: - path = os.path.join(os.path.dirname(__file__), 'plots', save_name) - os.makedirs(os.path.dirname(path), exist_ok=True) - plt.savefig(path) - print(f'saved: {path}') - # show - return fig - - -def plot_all(exp_name: str, gt_data_classes, tick_sizes: dict, samples: int, load=True, save=True, show_plt=True, show_dual_plt=False, save_plt=True, hide_extra_legends=False, fig_l_pad=1, fig_w=7, fig_h=13): - # generate data and plot! - dfs = {} - for data_name, data_cls in gt_data_classes.items(): - df = generate_data( - DisentDataset(data_cls(), transform=ToImgTensorF32()), - data_name, - batch_size=64, - samples=samples, - plot_diffs=False, - load_cache=load, - save_cache=save, - ) - dfs[data_name] = df - # flip overlap - df['distance'] = - df['overlap'] - del df['overlap'] - # plot ordered + shuffled - fig = dual_plot_from_generated_data( - df, - data_name=data_name, - save_name=f'{exp_name}/{data_name}_{samples}.png' if save_plt else None, - tick_size=tick_sizes.get(data_name, None), - fig_l_pad=fig_l_pad, - fig_w=fig_w, - fig_h=fig_h, - ) - - if show_dual_plt: - plt.show() - else: - plt.close(fig) - - def _all_plot_generated(dfs, ordered: bool, suffix: str): - fig = all_plot_from_all_generated_data( - dfs, - ordered=ordered, - save_name=f'{exp_name}/{exp_name}-{"ordered" if ordered else "shuffled"}{suffix}.png' if save_plt else None, - tick_sizes=tick_sizes, - hide_extra_legends=hide_extra_legends, - fig_l_pad=fig_l_pad, - fig_w=fig_w, - fig_h=fig_h, - ) - if show_plt: - plt.show() - else: - plt.close(fig) - - # all ordered plots - _all_plot_generated(dfs, ordered=True, suffix='') - _all_plot_generated({k: v for k, v in dfs.items() if k.lower().startswith('xy')}, ordered=True, suffix='-xy') - _all_plot_generated({k: v for k, v in dfs.items() if not k.lower().startswith('xy')}, ordered=True, suffix='-normal') - # all shuffled plots - _all_plot_generated(dfs, ordered=False, suffix='') - _all_plot_generated({k: v for k, v in dfs.items() if k.lower().startswith('xy')}, ordered=False, suffix='-xy') - _all_plot_generated({k: v for k, v in dfs.items() if not k.lower().startswith('xy')}, ordered=False, suffix='-normal') - # done! 
- return dfs - - -def plot_dfs_stacked(dfs, title: str, save_name: str = None, show_plt=True, tick_size: float = None, fig_l_pad=1, fig_w=7, fig_h=13, **kwargs): - # make new dataframe - df = pd.concat((df[df['samples']=='random'] for df in dfs.values())) - # make plot - cm = 1 / 2.54 - fig, ax = plt.subplots(1, 1, figsize=((fig_l_pad+1*fig_w)*cm, fig_h*cm)) - ax.set_title(title) - # plot - # sns.kdeplot(ax=ax, data=df, x="overlap", hue="data", bw_adjust=2) - sns.ecdfplot(ax=ax, data=df, x="overlap", hue="data") - # edit settins - # ax.set_ylim(-0.025, 1.025) - ax.set_xlabel('Overlap') - if tick_size is not None: - ax.xaxis.set_major_locator(MultipleLocator(base=tick_size)) - ax.set_ylabel('Cumulative Proportion') - plt.tight_layout() - # save - if save_name is not None: - path = os.path.join(os.path.dirname(__file__), 'plots', save_name) - os.makedirs(os.path.dirname(path), exist_ok=True) - plt.savefig(path) - print(f'saved: {path}') - # show - if show_plt: - plt.show() - else: - plt.close(fig) - - -def plot_unique_count(dfs, save_name: str = None, show_plt: bool = True, fig_l_pad=1, fig_w=1.5*7, fig_h=13): - df_uniques = pd.DataFrame({ - 'Grid Spacing': ['/'.join(data_name.split('-')[1:]) for data_name, df in dfs.items()], - 'Unique Overlap Values': [len(np.unique(df['overlap'].values, return_counts=True)[1]) for data_name, df in dfs.items()] - }) - # make plot - cm = 1 / 2.54 - fig, ax = plt.subplots(1, 1, figsize=((fig_l_pad+fig_w)*cm, fig_h*cm)) - ax.set_title('Increasing Overlap') - sns.barplot(data=df_uniques, x='Grid Spacing', y='Unique Overlap Values') - plt.gca().invert_xaxis() - plt.tight_layout() - # save - if save_name is not None: - path = os.path.join(os.path.dirname(__file__), 'plots', save_name) - os.makedirs(os.path.dirname(path), exist_ok=True) - plt.savefig(path) - print(f'saved: {path}') - # show - if show_plt: - plt.show() - else: - plt.close(fig) - - -# ========================================================================= # -# entrypoint # -# ========================================================================= # - - -if __name__ == '__main__': - - # TODO: update to new classes - # TODO: update to use registry - - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # common settings - SHARED_SETTINGS = dict( - samples=50_000, - load=True, - save=True, - show_plt=True, - save_plt=True, - show_dual_plt=False, - fig_l_pad=1, - fig_w=5.5, - fig_h=13, - tick_sizes={ - 'DSprites': 0.05, - 'Shapes3d': 0.2, - 'Cars3d': 0.05, - 'XYSquares': 0.01, - # increasing levels of overlap - 'XYSquares-1': 0.01, - 'XYSquares-2': 0.01, - 'XYSquares-3': 0.01, - 'XYSquares-4': 0.01, - 'XYSquares-5': 0.01, - 'XYSquares-6': 0.01, - 'XYSquares-7': 0.01, - 'XYSquares-8': 0.01, - # increasing levels of overlap 2 - 'XYSquares-1-8': 0.01, - 'XYSquares-2-8': 0.01, - 'XYSquares-3-8': 0.01, - 'XYSquares-4-8': 0.01, - 'XYSquares-5-8': 0.01, - 'XYSquares-6-8': 0.01, - 'XYSquares-7-8': 0.01, - 'XYSquares-8-8': 0.01, - }, - ) - - # EXPERIMENT 0 -- visual overlap on existing datasets - - dfs = plot_all( - exp_name='dataset-overlap', - gt_data_classes={ - # 'XYObject': wrapped_partial(XYObjectData), - # 'XYBlocks': wrapped_partial(XYBlocksData), - 'XYSquares': wrapped_partial(XYSquaresData), - 'DSprites': wrapped_partial(DSpritesData), - 'Shapes3d': wrapped_partial(Shapes3dData), - 'Cars3d': wrapped_partial(Cars3dData), - # 'SmallNorb': wrapped_partial(SmallNorbData), - # 'Mpi3d': wrapped_partial(Mpi3dData), - }, - hide_extra_legends=False, - 
**SHARED_SETTINGS - ) - - # EXPERIMENT 1 -- increasing visual overlap - - dfs = plot_all( - exp_name='increasing-overlap', - gt_data_classes={ - 'XYSquares-1': wrapped_partial(XYSquaresData, grid_spacing=1), - 'XYSquares-2': wrapped_partial(XYSquaresData, grid_spacing=2), - 'XYSquares-3': wrapped_partial(XYSquaresData, grid_spacing=3), - 'XYSquares-4': wrapped_partial(XYSquaresData, grid_spacing=4), - 'XYSquares-5': wrapped_partial(XYSquaresData, grid_spacing=5), - 'XYSquares-6': wrapped_partial(XYSquaresData, grid_spacing=6), - 'XYSquares-7': wrapped_partial(XYSquaresData, grid_spacing=7), - 'XYSquares-8': wrapped_partial(XYSquaresData, grid_spacing=8), - }, - hide_extra_legends=True, - **SHARED_SETTINGS - ) - - plot_unique_count( - dfs=dfs, - save_name='increasing-overlap/xysquares-increasing-overlap-counts.png', - ) - - plot_dfs_stacked( - dfs=dfs, - title='Increasing Overlap', - exp_name='increasing-overlap', - save_name='increasing-overlap/xysquares-increasing-overlap.png', - tick_size=0.01, - fig_w=13 - ) - - # EXPERIMENT 2 -- increasing visual overlap fixed dim size - - dfs = plot_all( - exp_name='increasing-overlap-fixed', - gt_data_classes={ - 'XYSquares-1-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=1, grid_size=8), - 'XYSquares-2-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=2, grid_size=8), - 'XYSquares-3-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=3, grid_size=8), - 'XYSquares-4-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=4, grid_size=8), - 'XYSquares-5-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=5, grid_size=8), - 'XYSquares-6-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=6, grid_size=8), - 'XYSquares-7-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=7, grid_size=8), - 'XYSquares-8-8': wrapped_partial(XYSquaresData, square_size=8, grid_spacing=8, grid_size=8), - }, - hide_extra_legends=True, - **SHARED_SETTINGS - ) - - plot_unique_count( - dfs=dfs, - save_name='increasing-overlap-fixed/xysquares-increasing-overlap-fixed-counts.png', - ) - - plot_dfs_stacked( - dfs=dfs, - title='Increasing Overlap', - exp_name='increasing-overlap-fixed', - save_name='increasing-overlap-fixed/xysquares-increasing-overlap-fixed.png', - tick_size=0.01, - fig_w=13 - ) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e01_visual_overlap/run_plot_traversal_dists.py b/research/e01_visual_overlap/run_plot_traversal_dists.py deleted file mode 100644 index 56283d35..00000000 --- a/research/e01_visual_overlap/run_plot_traversal_dists.py +++ /dev/null @@ -1,654 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import os -from collections import defaultdict -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Literal -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Union - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn.functional as F -from tqdm import tqdm - -import research.util as H -from disent.dataset.data import GroundTruthData -from disent.dataset.data import SelfContainedHdf5GroundTruthData -from disent.dataset.util.state_space import NonNormalisedFactors -from disent.dataset.transform import ToImgTensorF32 -from disent.dataset.util.stats import compute_data_mean_std -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.profiling import Timer -from disent.util.seeds import TempNumpySeed - - -# ========================================================================= # -# Factor Traversal Stats # -# ========================================================================= # - - -SampleModeHint = Union[Literal['random'], Literal['near'], Literal['combinations']] - - -@torch.no_grad() -def sample_factor_traversal_info( - gt_data: GroundTruthData, - f_idx: Optional[int] = None, - circular_distance: bool = False, - sample_mode: SampleModeHint = 'random', -) -> dict: - # load traversal -- TODO: this is the bottleneck! not threaded - factors, indices, obs = gt_data.sample_random_obs_traversal(f_idx=f_idx, obs_collect_fn=torch.stack) - # get pairs - idxs_a, idxs_b = H.pair_indices(max_idx=len(indices), mode=sample_mode) - # compute deltas - deltas = F.mse_loss(obs[idxs_a], obs[idxs_b], reduction='none').mean(dim=[-3, -2, -1]).numpy() - fdists = H.np_factor_dists(factors[idxs_a], factors[idxs_b], factor_sizes=gt_data.factor_sizes, circular_if_factor_sizes=circular_distance, p=1) - # done! - return dict( - # traversals - factors=factors, # np.ndarray - indices=indices, # np.ndarray - obs=obs, # torch.Tensor - # pairs - idxs_a=idxs_a, # np.ndarray - idxs_b=idxs_b, # np.ndarray - deltas=deltas, # np.ndarray - fdists=fdists, # np.ndarray - ) - - -def sample_factor_traversal_info_and_distmat( - gt_data: GroundTruthData, - f_idx: Optional[int] = None, - circular_distance: bool = False, -) -> dict: - dat = sample_factor_traversal_info(gt_data=gt_data, f_idx=f_idx, sample_mode='combinations', circular_distance=circular_distance) - # extract - factors, idxs_a, idxs_b, deltas, fdists = dat['factors'], dat['idxs_a'], dat['idxs_b'], dat['deltas'], dat['fdists'] - # generate deltas matrix - deltas_matrix = np.zeros([factors.shape[0], factors.shape[0]]) - deltas_matrix[idxs_a, idxs_b] = deltas - deltas_matrix[idxs_b, idxs_a] = deltas - # generate distance matrix - fdists_matrix = np.zeros([factors.shape[0], factors.shape[0]]) - fdists_matrix[idxs_a, idxs_b] = fdists - fdists_matrix[idxs_b, idxs_a] = fdists - # done! 
-    return dict(**dat, deltas_matrix=deltas_matrix, fdists_matrix=fdists_matrix)
-
-
-# ========================================================================= #
-# Factor Traversal Collector                                                #
-# ========================================================================= #
-
-
-def _collect_stats_for_factors(
-    gt_data: GroundTruthData,
-    f_idxs: Sequence[int],
-    stats_fn: Callable[[GroundTruthData, int, int], Dict[str, Any]],
-    keep_keys: Sequence[str],
-    stats_callback: Optional[Callable[[Dict[str, List[Any]], int, int], None]] = None,
-    return_stats: bool = True,
-    num_traversal_sample: int = 100,
-) -> List[Dict[str, List[Any]]]:
-    # prepare
-    f_idxs = gt_data.normalise_factor_idxs(f_idxs)
-    # generate data per factor
-    f_stats = []
-    for i, f_idx in enumerate(f_idxs):
-        factor_name = gt_data.factor_names[f_idx]
-        factor_size = gt_data.factor_sizes[f_idx]
-        # repeatedly generate stats per factor
-        stats = defaultdict(list)
-        for _ in tqdm(range(num_traversal_sample), desc=f'{gt_data.name}: {factor_name}'):
-            data = stats_fn(gt_data, i, f_idx)
-            for key in keep_keys:
-                stats[key].append(data[key])
-        # save factor stats
-        if return_stats:
-            f_stats.append(stats)
-        if stats_callback:
-            stats_callback(stats, i, f_idx)
-    # done!
-    if return_stats:
-        return f_stats
-
-
-# ========================================================================= #
-# Plot Traversal Stats                                                      #
-# ========================================================================= #
-
-
-_COLORS = {
-    'blue': (None, 'Blues', 'Blues'),
-    'red': (None, 'Reds', 'Reds'),
-    'purple': (None, 'Purples', 'Purples'),
-    'green': (None, 'Greens', 'Greens'),
-    'orange': (None, 'Oranges', 'Oranges'),
-}
-
-
-def plot_traversal_stats(
-    dataset_or_name: Union[str, GroundTruthData],
-    num_repeats: int = 256,
-    f_idxs: Optional[NonNormalisedFactors] = None,
-    circular_distance: bool = False,
-    color='blue',
-    color_gt_dist='blue',
-    color_im_dist='purple',
-    suffix: Optional[str] = None,
-    save_path: Optional[str] = None,
-    plot_freq: bool = True,
-    plot_title: Union[bool, str] = False,
-    fig_block_size: float = 4.0,
-    col_titles: Union[bool, List[str]] = True,
-    hide_axis: bool = True,
-    hide_labels: bool = True,
-    y_size_offset: float = 0.0,
-    x_size_offset: float = 0.0,
-):
-    # - - - - - - - - - - - - - - - - - #
-
-    def stats_fn(gt_data, i, f_idx):
-        return sample_factor_traversal_info_and_distmat(gt_data=gt_data, f_idx=f_idx, circular_distance=circular_distance)
-
-    def plot_ax(stats: dict, i: int, f_idx: int):
-        deltas = np.concatenate(stats['deltas'])
-        fdists = np.concatenate(stats['fdists'])
-        fdists_matrix = np.mean(stats['fdists_matrix'], axis=0)
-        deltas_matrix = np.mean(stats['deltas_matrix'], axis=0)
-
-        # if we limit the number of points below, shuffle first so that the kept subset is still representative
-        with TempNumpySeed(777): np.random.shuffle(deltas)
-        with TempNumpySeed(777): np.random.shuffle(fdists)
-
-        # subplot!
-        if plot_freq:
-            ax0, ax1, ax2, ax3 = axs[:, i]
-        else:
-            (ax0, ax1), (ax2, ax3) = (None, None), axs[:, i]
-
-        # get title
-        curr_title = None
-        if isinstance(col_titles, bool):
-            if col_titles:
-                curr_title = gt_data.factor_names[f_idx]
-        else:
-            curr_title = col_titles[i]
-
-        # set column titles
-        if curr_title is not None:
-            (ax0 if plot_freq else ax2).set_title(f'{curr_title}\n', fontsize=24)
-
-        # plot the frequency stats
-        if plot_freq:
-            ax0.violinplot([deltas], vert=False)
-            ax0.set_xlabel('deltas')
-            ax0.set_ylabel('proportion')
-
-            ax1.set_title('deltas vs. fdists')
-            ax1.scatter(x=deltas[:15_000], y=fdists[:15_000], s=20, alpha=0.1, c=c_points)
-            H.plt_2d_density(
-                x=deltas[:10_000], xmin=deltas.min(), xmax=deltas.max(),
-                y=fdists[:10_000], ymin=fdists.min() - 0.5, ymax=fdists.max() + 0.5,
-                n_bins=100,
-                ax=ax1, pcolormesh_kwargs=dict(cmap=cmap_density, alpha=0.5),
-            )
-            ax1.set_xlabel('deltas')
-            ax1.set_ylabel('fdists')
-
-        # ax2.set_title('fdists')
-        ax2.imshow(fdists_matrix, cmap=gt_cmap_img)
-        if not hide_labels: ax2.set_xlabel('f_idx')
-        if not hide_labels: ax2.set_ylabel('f_idx')
-        if hide_axis: H.plt_hide_axis(ax2)
-
-        # ax3.set_title('divergence')
-        ax3.imshow(deltas_matrix, cmap=im_cmap_img)
-        if not hide_labels: ax3.set_xlabel('f_idx')
-        if not hide_labels: ax3.set_ylabel('f_idx')
-        if hide_axis: H.plt_hide_axis(ax3)
-
-
-    # - - - - - - - - - - - - - - - - - #
-
-    # initialize
-    gt_data: GroundTruthData = H.make_data(dataset_or_name) if isinstance(dataset_or_name, str) else dataset_or_name
-    f_idxs = gt_data.normalise_factor_idxs(f_idxs)
-
-    c_points, cmap_density, cmap_img = _COLORS[color]
-    im_c_points, im_cmap_density, im_cmap_img = _COLORS[color if (color_im_dist is None) else color_im_dist]
-    gt_c_points, gt_cmap_density, gt_cmap_img = _COLORS[color if (color_gt_dist is None) else color_gt_dist]
-
-    n = 4 if plot_freq else 2
-
-    # get additional spacing
-    title_offset = 0 if (isinstance(col_titles, bool) and not col_titles) else 0.15
-
-    # settings
-    r, c = [n, len(f_idxs)]
-    h, w = [(n+title_offset)*fig_block_size + y_size_offset, len(f_idxs)*fig_block_size + x_size_offset]
-
-    # initialize plot
-    fig, axs = plt.subplots(r, c, figsize=(w, h), squeeze=False)
-
-    if isinstance(plot_title, str):
-        fig.suptitle(f'{plot_title}\n', fontsize=25)
-    elif plot_title:
-        fig.suptitle(f'{gt_data.name} [circular={circular_distance}]{f" {suffix}" if suffix else ""}\n', fontsize=25)
-
-    # generate plot
-    _collect_stats_for_factors(
-        gt_data=gt_data,
-        f_idxs=f_idxs,
-        stats_fn=stats_fn,
-        keep_keys=['deltas', 'fdists', 'deltas_matrix', 'fdists_matrix'],
-        stats_callback=plot_ax,
-        num_traversal_sample=num_repeats,
-    )
-
-    # finalize plot
-    fig.tight_layout() # (pad=1.4 if hide_labels else 1.08)
-
-    # save the path
-    if save_path is not None:
-        assert save_path.endswith('.png')
-        ensure_parent_dir_exists(save_path)
-        plt.savefig(save_path)
-        print(f'saved {gt_data.name} to: {save_path}')
-
-    # show it!
- plt.show() - - # - - - - - - - - - - - - - - - - - # - return fig - - -# TODO: fix -def plot_traversal_stats( - dataset_or_name: Union[str, GroundTruthData], - num_repeats: int = 256, - f_idxs: Optional[NonNormalisedFactors] = None, - circular_distance: bool = False, - color='blue', - color_gt_dist='blue', - color_im_dist='purple', - suffix: Optional[str] = None, - save_path: Optional[str] = None, - plot_freq: bool = True, - plot_title: Union[bool, str] = False, - plt_scale: float = 6, - col_titles: Union[bool, List[str]] = True, - hide_axis: bool = True, - hide_labels: bool = True, - y_size_offset: float = 0.45, - x_size_offset: float = 0.75, - disable_labels: bool = False, - bottom_labels: bool = False, - label_size: int = 23, -): - # - - - - - - - - - - - - - - - - - # - - def stats_fn(gt_data, i, f_idx): - return sample_factor_traversal_info_and_distmat( - gt_data=gt_data, f_idx=f_idx, circular_distance=circular_distance - ) - - grid_t = [] - grid_titles = [] - - def plot_ax(stats: dict, i: int, f_idx: int): - fdists_matrix = np.mean(stats['fdists_matrix'], axis=0) - deltas_matrix = np.mean(stats['deltas_matrix'], axis=0) - grid_t.append([fdists_matrix, deltas_matrix]) - # get the title - if isinstance(col_titles, bool): - if col_titles: - grid_titles.append(gt_data.factor_names[f_idx]) - else: - grid_titles.append(col_titles[i]) - - # initialize - gt_data: GroundTruthData = H.make_data(dataset_or_name) if isinstance(dataset_or_name, str) else dataset_or_name - f_idxs = gt_data.normalise_factor_idxs(f_idxs) - - # get title - if isinstance(plot_title, str): - suptitle = f'{plot_title}' - elif plot_title: - suptitle = f'{gt_data.name} [circular={circular_distance}]{f" {suffix}" if suffix else ""}' - else: - suptitle = None - - # generate plot - _collect_stats_for_factors( - gt_data=gt_data, - f_idxs=f_idxs, - stats_fn=stats_fn, - keep_keys=['deltas', 'fdists', 'deltas_matrix', 'fdists_matrix'], - stats_callback=plot_ax, - num_traversal_sample=num_repeats, - ) - - labels = None - if (not disable_labels) and grid_titles: - labels = grid_titles - - # settings - fig, axs = H.plt_subplots_imshow( - grid=list(zip(*grid_t)), - title=suptitle, - titles=None if bottom_labels else labels, - titles_size=label_size, - col_labels=labels if bottom_labels else None, - label_size=label_size, - subplot_padding=None, - figsize=((1/2.54) * len(f_idxs) * plt_scale + x_size_offset, (1/2.54) * (2) * plt_scale + y_size_offset) - ) - - # recolor axes - for (ax0, ax1) in axs.T: - ax0.images[0].set_cmap('Blues') - ax1.images[0].set_cmap('Purples') - - fig.tight_layout() - - # save the path - if save_path is not None: - assert save_path.endswith('.png') - ensure_parent_dir_exists(save_path) - plt.savefig(save_path) - print(f'saved {gt_data.name} to: {save_path}') - - # show it! 
- plt.show() - - # - - - - - - - - - - - - - - - - - # - return fig - - -# ========================================================================= # -# MAIN - DISTS # -# ========================================================================= # - - -@torch.no_grad() -def factor_stats(gt_data: GroundTruthData, f_idxs=None, min_samples: int = 100_000, min_repeats: int = 5000, recon_loss: str = 'mse', sample_mode: str = 'random') -> Tuple[Sequence[int], List[np.ndarray]]: - from disent.registry import RECON_LOSSES - from disent.frameworks.helper.reconstructions import ReconLossHandler - recon_loss: ReconLossHandler = RECON_LOSSES[recon_loss](reduction='mean') - - f_dists = [] - f_idxs = gt_data.normalise_factor_idxs(f_idxs) - # for each factor - for f_idx in f_idxs: - dists = [] - with tqdm(desc=gt_data.factor_names[f_idx], total=min_samples) as p: - # for multiple random factor traversals along the factor - while len(dists) < min_samples or p.n < min_repeats: - # based on: sample_factor_traversal_info(...) # TODO: should add recon loss to that function instead - factors, indices, obs = gt_data.sample_random_obs_traversal(f_idx=f_idx, obs_collect_fn=torch.stack) - # random pairs -- we use this because it does not include [i == i] - idxs_a, idxs_b = H.pair_indices(max_idx=len(indices), mode=sample_mode) - # get distances - d = recon_loss.compute_pairwise_loss(obs[idxs_a], obs[idxs_b]) - d = d.numpy().tolist() - # H.plt_subplots_imshow([[np.moveaxis(o.numpy(), 0, -1) for o in obs]]) - # plt.show() - dists.extend(d) - p.update(len(d)) - # aggregate the average distances - f_dists.append(np.array(dists)[:min_samples]) - - return f_idxs, f_dists - - -def get_random_dists(gt_data: GroundTruthData, num_samples: int = 100_000, recon_loss: str = 'mse'): - from disent.registry import RECON_LOSSES - from disent.frameworks.helper.reconstructions import ReconLossHandler - recon_loss: ReconLossHandler = RECON_LOSSES[recon_loss](reduction='mean') - - dists = [] - with tqdm(desc=gt_data.name, total=num_samples) as p: - # for multiple random factor traversals along the factor - while len(dists) < num_samples: - # random pair - i, j = np.random.randint(0, len(gt_data), size=2) - # get distance - d = recon_loss.compute_pairwise_loss(gt_data[i][None, ...], gt_data[j][None, ...]) - # plt.show() - dists.append(float(d.flatten())) - p.update() - # done! 
- return np.array(dists) - - -def print_ave_dists(gt_data: GroundTruthData, num_samples: int = 100_000, recon_loss: str = 'mse'): - dists = get_random_dists(gt_data=gt_data, num_samples=num_samples, recon_loss=recon_loss) - f_mean = np.mean(dists) - f_std = np.std(dists) - print(f'[{gt_data.name}] RANDOM ({len(gt_data)}, {len(dists)}) - mean: {f_mean:7.4f} std: {f_std:7.4f}') - - -def print_ave_factor_stats(gt_data: GroundTruthData, f_idxs=None, min_samples: int = 100_000, min_repeats: int = 5000, recon_loss: str = 'mse', sample_mode: str = 'random'): - # compute average distances - f_idxs, f_dists = factor_stats(gt_data=gt_data, f_idxs=f_idxs, min_repeats=min_repeats, min_samples=min_samples, recon_loss=recon_loss, sample_mode=sample_mode) - # compute dists - f_means = [np.mean(d) for d in f_dists] - f_stds = [np.std(d) for d in f_dists] - # sort in order of importance - order = np.argsort(f_means)[::-1] - # print information - for i in order: - f_idx, f_mean, f_std = f_idxs[i], f_means[i], f_stds[i] - print(f'[{gt_data.name}] {gt_data.factor_names[f_idx]} ({gt_data.factor_sizes[f_idx]}, {len(f_dists[f_idx])}) - mean: {f_mean:7.4f} std: {f_std:7.4f}') - - -def main_compute_dists(factor_samples: int = 50_000, min_repeats: int = 5000, random_samples: int = 50_000, recon_loss: str = 'mse', sample_mode: str = 'random', seed: int = 777): - # plot standard datasets - for name in ['dsprites', 'shapes3d', 'cars3d', 'smallnorb', 'xysquares_8x8_s8']: - gt_data = H.make_data(name) - if factor_samples is not None: - with TempNumpySeed(seed): - print_ave_factor_stats(gt_data, min_samples=factor_samples, min_repeats=min_repeats, recon_loss=recon_loss, sample_mode=sample_mode) - if random_samples is not None: - with TempNumpySeed(seed): - print_ave_dists(gt_data, num_samples=random_samples, recon_loss=recon_loss) - -# [dsprites] position_y (32, 50000) - mean: 0.0584 std: 0.0378 -# [dsprites] position_x (32, 50000) - mean: 0.0559 std: 0.0363 -# [dsprites] scale (6, 50000) - mean: 0.0250 std: 0.0148 -# [dsprites] shape (3, 50000) - mean: 0.0214 std: 0.0095 -# [dsprites] orientation (40, 50000) - mean: 0.0172 std: 0.0106 -# [dsprites] RANDOM (737280, 50000) - mean: 0.0754 std: 0.0289 - -# [3dshapes] wall_hue (10, 50000) - mean: 0.1122 std: 0.0661 -# [3dshapes] floor_hue (10, 50000) - mean: 0.1086 std: 0.0623 -# [3dshapes] object_hue (10, 50000) - mean: 0.0416 std: 0.0292 -# [3dshapes] shape (4, 50000) - mean: 0.0207 std: 0.0161 -# [3dshapes] scale (8, 50000) - mean: 0.0182 std: 0.0153 -# [3dshapes] orientation (15, 50000) - mean: 0.0116 std: 0.0079 -# [3dshapes] RANDOM (480000, 50000) - mean: 0.2432 std: 0.0918 - -# [cars3d] azimuth (24, 50000) - mean: 0.0355 std: 0.0185 -# [cars3d] object_type (183, 50000) - mean: 0.0349 std: 0.0176 -# [cars3d] elevation (4, 50000) - mean: 0.0174 std: 0.0100 -# [cars3d] RANDOM (17568, 50000) - mean: 0.0519 std: 0.0188 - -# [smallnorb] lighting (6, 50000) - mean: 0.0531 std: 0.0563 -# [smallnorb] category (5, 50000) - mean: 0.0113 std: 0.0066 -# [smallnorb] rotation (18, 50000) - mean: 0.0090 std: 0.0071 -# [smallnorb] instance (5, 50000) - mean: 0.0068 std: 0.0048 -# [smallnorb] elevation (9, 50000) - mean: 0.0034 std: 0.0030 -# [smallnorb] RANDOM (24300, 50000) - mean: 0.0535 std: 0.0529 - -# [xy_squares] y_B (8, 50000) - mean: 0.0104 std: 0.0000 -# [xy_squares] x_B (8, 50000) - mean: 0.0104 std: 0.0000 -# [xy_squares] y_G (8, 50000) - mean: 0.0104 std: 0.0000 -# [xy_squares] x_G (8, 50000) - mean: 0.0104 std: 0.0000 -# [xy_squares] y_R (8, 50000) - mean: 0.0104 std: 
0.0000 -# [xy_squares] x_R (8, 50000) - mean: 0.0104 std: 0.0000 -# [xy_squares] RANDOM (262144, 50000) - mean: 0.0308 std: 0.0022 - -# ========================================================================= # -# MAIN - PLOTTING # -# ========================================================================= # - - -def _make_self_contained_dataset(h5_path): - return SelfContainedHdf5GroundTruthData(h5_path=h5_path, transform=ToImgTensorF32()) - - -def _print_data_mean_std(data_or_name, print_mean_std: bool = True): - if print_mean_std: - data = H.make_data(data_or_name) if isinstance(data_or_name, str) else data_or_name - name = data_or_name if isinstance(data_or_name, str) else data.name - mean, std = compute_data_mean_std(data) - print(f'{name}\n vis_mean: {mean.tolist()}\n vis_std: {std.tolist()}') - - -def main_plotting(plot_all=False, print_mean_std=False): - CIRCULAR = False - PLOT_FREQ = False - - def sp(name): - prefix = 'CIRCULAR_' if CIRCULAR else 'DIST_' - prefix = prefix + ('FREQ_' if PLOT_FREQ else 'NO-FREQ_') - return os.path.join(os.path.dirname(__file__), 'plots', f'{prefix}{name}.png') - - # plot xysquares with increasing overlap - for s in [1, 2, 3, 4, 5, 6, 7, 8]: - plot_traversal_stats(circular_distance=CIRCULAR, plt_scale=8, label_size=26, x_size_offset=0, y_size_offset=0.6, save_path=sp(f'xysquares_8x8_s{s}'), color='blue', dataset_or_name=f'xysquares_8x8_s{s}', f_idxs=[1], col_titles=[f'Space: {s}px'], plot_freq=PLOT_FREQ) - _print_data_mean_std(f'xysquares_8x8_s{s}', print_mean_std) - - # plot standard datasets - for name in ['dsprites', 'shapes3d', 'cars3d', 'smallnorb']: - plot_traversal_stats(circular_distance=CIRCULAR, x_size_offset=0, y_size_offset=0.6, num_repeats=256, disable_labels=False, save_path=sp(name), color='blue', dataset_or_name=name, plot_freq=PLOT_FREQ) - _print_data_mean_std(name, print_mean_std) - - if not plot_all: - return - - # plot adversarial dsprites datasets - for fg in [True, False]: - for vis in [100, 80, 60, 40, 20]: - name = f'dsprites_imagenet_{"fg" if fg else "bg"}_{vis}' - plot_traversal_stats(circular_distance=CIRCULAR, save_path=sp(name), color='orange', dataset_or_name=name, plot_freq=PLOT_FREQ, x_size_offset=0.4) - _print_data_mean_std(name, print_mean_std) - - BASE = os.path.abspath(os.path.join(__file__, '../../../out/adversarial_data_approx')) - - # plot adversarial datasets - for color, folder in [ - # 'const' datasets - ('purple', '2021-08-18--00-58-22_FINAL-dsprites_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('purple', '2021-08-18--01-33-47_FINAL-shapes3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('purple', '2021-08-18--02-20-13_FINAL-cars3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('purple', '2021-08-18--03-10-53_FINAL-smallnorb_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - # 'invert' datasets - ('orange', '2021-08-18--03-52-31_FINAL-dsprites_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('orange', '2021-08-18--04-29-25_FINAL-shapes3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('orange', '2021-08-18--05-13-15_FINAL-cars3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - ('orange', '2021-08-18--06-03-32_FINAL-smallnorb_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06'), - # stronger 'invert' datasets - ('red', 
'2021-09-06--00-29-23_INVERT-VSTRONG-shapes3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('red', '2021-09-06--03-17-28_INVERT-VSTRONG-dsprites_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('red', '2021-09-06--05-42-06_INVERT-VSTRONG-cars3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ('red', '2021-09-06--09-10-59_INVERT-VSTRONG-smallnorb_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06'), - ]: - data = _make_self_contained_dataset(f'{BASE}/{folder}/data.h5') - plot_traversal_stats(circular_distance=CIRCULAR, save_path=sp(folder), color=color, dataset_or_name=data, plot_freq=PLOT_FREQ, x_size_offset=0.4) - _print_data_mean_std(data, print_mean_std) - - -# ========================================================================= # -# STATS # -# ========================================================================= # - - -if __name__ == '__main__': - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - # run! - # main_plotting() - main_compute_dists() - - -# ========================================================================= # -# STATS # -# ========================================================================= # - - -# 2021-08-18--00-58-22_FINAL-dsprites_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.04375297] -# vis_std: [0.06837677] -# 2021-08-18--01-33-47_FINAL-shapes3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.48852729, 0.5872147 , 0.59863929] -# vis_std: [0.08931785, 0.18920148, 0.23331079] -# 2021-08-18--02-20-13_FINAL-cars3d_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.88888636, 0.88274618, 0.87782785] -# vis_std: [0.18967542, 0.20009377, 0.20805905] -# 2021-08-18--03-10-53_FINAL-smallnorb_self_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.74029344] -# vis_std: [0.06706581] -# -# 2021-08-18--03-52-31_FINAL-dsprites_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.0493243] -# vis_std: [0.09729655] -# 2021-08-18--04-29-25_FINAL-shapes3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.49514523, 0.58791172, 0.59616399] -# vis_std: [0.08637031, 0.1895267 , 0.23397072] -# 2021-08-18--05-13-15_FINAL-cars3d_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.88851889, 0.88029857, 0.87666017] -# vis_std: [0.200735 , 0.2151134, 0.2217553] -# 2021-08-18--06-03-32_FINAL-smallnorb_invert_margin_0.005_aw10.0_close_p_random_n_s50001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.73232105] -# vis_std: [0.08755041] -# -# 2021-09-06--00-29-23_INVERT-VSTRONG-shapes3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.47992192, 0.51311111, 0.54627272] -# vis_std: [0.28653814, 0.29201543, 0.27395435] -# 2021-09-06--03-17-28_INVERT-VSTRONG-dsprites_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.20482841] -# vis_std: [0.33634909] -# 2021-09-06--05-42-06_INVERT-VSTRONG-cars3d_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.76418207, 0.75554032, 0.75075393] -# vis_std: [0.31892905, 0.32751031, 0.33319886] -# 2021-09-06--09-10-59_INVERT-VSTRONG-smallnorb_invert_margin_0.05_aw10.0_same_k1_close_s200001_Adam_lr0.0005_wd1e-06 -# vis_mean: [0.69691603] -# vis_std: [0.21310608] - - -# 
========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e01_visual_overlap/util_compute_traversal_dist_pairs.py b/research/e01_visual_overlap/util_compute_traversal_dist_pairs.py deleted file mode 100644 index 3ecb9713..00000000 --- a/research/e01_visual_overlap/util_compute_traversal_dist_pairs.py +++ /dev/null @@ -1,274 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -from pathlib import Path -from typing import Optional - -import numpy as np -import psutil -import ray -import torch -from ray.util.queue import Queue -from tqdm import tqdm - -import research.util as H -from disent.dataset.data import GroundTruthData -from disent.util.inout.files import AtomicSaveFile -from disent.util.profiling import Timer -from disent.util.seeds import TempNumpySeed - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Dataset Distances # -# ========================================================================= # - - -@ray.remote -def _compute_given_dists(gt_data, idxs, obs_pair_idxs, progress_queue=None): - # checks - assert idxs.ndim == 1 - assert obs_pair_idxs.ndim == 2 - assert len(obs_pair_idxs) == len(idxs) - # storage - with torch.no_grad(), Timer() as timer: - obs_pair_dists = torch.zeros(*obs_pair_idxs.shape, dtype=torch.float32) - # progress - done = 0 - # for each observation - for i, (obs_idx, pair_idxs) in enumerate(zip(idxs, obs_pair_idxs)): - # load data - obs = gt_data[obs_idx].flatten() - batch = torch.stack([gt_data[i].flatten() for i in pair_idxs], dim=0) - # compute distances - obs_pair_dists[i, :] = torch.mean((batch - obs[None, :])**2, dim=-1, dtype=torch.float32) - # add progress - done += 1 - if progress_queue is not None: - if timer.elapsed > 0.2: - timer.restart() - progress_queue.put(done) - done = 0 - # final update - if progress_queue is not None: - if done > 0: - progress_queue.put(done) - # done! - return obs_pair_dists.numpy() - - -def compute_dists(gt_data: GroundTruthData, obs_pair_idxs: np.ndarray, jobs_per_cpu: int = 1): - """ - Compute all the distances for ground truth data. 
-
-    obs_pair_idxs is a 2D array (len(gt_data), N) of paired indices
-    for each element in the dataset.
-    """
-    # checks
-    assert obs_pair_idxs.ndim == 2
-    assert obs_pair_idxs.shape[0] == len(gt_data)
-    assert jobs_per_cpu > 0
-    # get workers
-    num_cpus = int(ray.available_resources().get('CPU', 1))
-    num_workers = int(num_cpus * jobs_per_cpu)
-    # get chunks
-    pair_idxs_chunks = np.array_split(obs_pair_idxs, num_workers)
-    start_idxs = [0] + np.cumsum([len(c) for c in pair_idxs_chunks]).tolist()
-    # progress queue
-    progress_queue = Queue()
-    ref_gt_data = ray.put(gt_data)
-    # make workers
-    futures = [
-        _compute_given_dists.remote(ref_gt_data, np.arange(i, i+len(chunk)), chunk, progress_queue)
-        for i, chunk in zip(start_idxs, pair_idxs_chunks)
-    ]
-    # check progress
-    with tqdm(desc=gt_data.name, total=len(gt_data)) as progress:
-        completed = 0
-        while completed < len(gt_data):
-            done = progress_queue.get()
-            completed += done
-            progress.update(done)
-    # done
-    obs_pair_dists = np.concatenate(ray.get(futures), axis=0)
-    return obs_pair_dists
-
-# ========================================================================= #
-# Distance Types                                                            #
-# ========================================================================= #
-
-
-def dataset_pair_idxs__random(gt_data: GroundTruthData, num_pairs: int = 25) -> np.ndarray:
-    # purely random pairs...
-    return np.random.randint(0, len(gt_data), size=[len(gt_data), num_pairs])
-
-
-def dataset_pair_idxs__nearby(gt_data: GroundTruthData, num_pairs: int = 10, radius: int = 5) -> np.ndarray:
-    radius = np.array(radius)
-    assert radius.ndim in (0, 1)
-    if radius.ndim == 1:
-        assert radius.shape == (gt_data.num_factors,)
-    # get all positions
-    pos = gt_data.idx_to_pos(np.arange(len(gt_data)))
-    # generate random offsets
-    offsets = np.random.randint(-radius, radius + 1, size=[len(gt_data), num_pairs, gt_data.num_factors])
-    # broadcast random offsets & wrap around
-    nearby_pos = (pos[:, None, :] + offsets) % gt_data.factor_sizes
-    # convert back to indices
-    nearby_idxs = gt_data.pos_to_idx(nearby_pos)
-    # done!
- return nearby_idxs - - -def dataset_pair_idxs__nearby_scaled(gt_data: GroundTruthData, num_pairs: int = 10, min_radius: int = 2, radius_ratio: float = 0.2) -> np.ndarray: - return dataset_pair_idxs__nearby( - gt_data=gt_data, - num_pairs=num_pairs, - radius=np.maximum((np.array(gt_data.factor_sizes) * radius_ratio).astype('int'), min_radius), - ) - - -_PAIR_IDXS_FNS = { - 'random': dataset_pair_idxs__random, - 'nearby': dataset_pair_idxs__nearby, - 'nearby_scaled': dataset_pair_idxs__nearby_scaled, -} - - -def dataset_pair_idxs(mode: str, gt_data: GroundTruthData, num_pairs: int = 10, **kwargs): - if mode not in _PAIR_IDXS_FNS: - raise KeyError(f'invalid mode: {repr(mode)}, must be one of: {sorted(_PAIR_IDXS_FNS.keys())}') - return _PAIR_IDXS_FNS[mode](gt_data, num_pairs=num_pairs, **kwargs) - - -# ========================================================================= # -# Cache Distances # -# ========================================================================= # - -def _get_default_seed( - pairs_per_obs: int, - pair_mode: str, - dataset_name: str, -): - import hashlib - seed_key = (pairs_per_obs, pair_mode, dataset_name) - seed_hash = hashlib.md5(str(seed_key).encode()) - seed = int(seed_hash.hexdigest()[:8], base=16) % (2**32) # [0, 2**32-1] - return seed - - -def cached_compute_dataset_pair_dists( - dataset_name: str = 'smallnorb', - pair_mode: str = 'nearby_scaled', # random, nearby, nearby_scaled - pairs_per_obs: int = 64, - seed: Optional[int] = None, - # cache settings - cache_dir: str = 'data/cache', - force: bool = False, - # normalize - scaled: bool = True, -): - # checks - assert (seed is None) or isinstance(seed, int), f'seed must be an int or None, got: {type(seed)}' - assert isinstance(pairs_per_obs, int), f'pairs_per_obs must be an int, got: {type(pairs_per_obs)}' - assert pair_mode in _PAIR_IDXS_FNS, f'pair_mode is invalid, got: {repr(pair_mode)}, must be one of: {sorted(_PAIR_IDXS_FNS.keys())}' - # get default seed - if seed is None: - seed = _get_default_seed(pairs_per_obs=pairs_per_obs, pair_mode=pair_mode, dataset_name=dataset_name) - # cache path - cache_path = Path(cache_dir, f'dist-pairs_{dataset_name}_{pairs_per_obs}_{pair_mode}_{seed}.npz') - # generate if it does not exist - if force or not cache_path.exists(): - log.info(f'generating cached distances for: {dataset_name} to: {cache_path}') - # load data - gt_data = H.make_data(dataset_name, transform_mode='float32') - # generate idxs - with TempNumpySeed(seed=seed): - obs_pair_idxs = dataset_pair_idxs(pair_mode, gt_data, num_pairs=pairs_per_obs) - obs_pair_dists = compute_dists(gt_data, obs_pair_idxs) - # generate & save - with AtomicSaveFile(file=cache_path, overwrite=force) as path: - np.savez(path, **{ - 'dataset_name': dataset_name, - 'seed': seed, - 'obs_pair_idxs': obs_pair_idxs, - 'obs_pair_dists': obs_pair_dists, - }) - # load cached data - else: - log.info(f'loading cached distances for: {dataset_name} from: {cache_path}') - data = np.load(cache_path) - obs_pair_idxs = data['obs_pair_idxs'] - obs_pair_dists = data['obs_pair_dists'] - # normalize the max distance to 1.0 - if scaled: - obs_pair_dists /= np.max(obs_pair_dists) - # done! - return obs_pair_idxs, obs_pair_dists - - -# ========================================================================= # -# TEST! 
# -# ========================================================================= # - - -def generate_common_cache(force=False, force_seed=None): - import itertools - # settings - sweep_pairs_per_obs = [128, 32, 256, 64, 16] - sweep_pair_modes = ['nearby_scaled', 'random', 'nearby'] - sweep_dataset_names = ['cars3d', 'smallnorb', 'shapes3d', 'dsprites', 'xysquares'] - # info - log.info(f'Computing distances for sweep of size: {len(sweep_pairs_per_obs)*len(sweep_pair_modes)*len(sweep_dataset_names)}') - # sweep - for i, (pairs_per_obs, pair_mode, dataset_name) in enumerate(itertools.product(sweep_pairs_per_obs, sweep_pair_modes, sweep_dataset_names)): - # deterministic seed based on settings - if force_seed is None: - seed = _get_default_seed(pairs_per_obs=pairs_per_obs, pair_mode=pair_mode, dataset_name=dataset_name) - else: - seed = force_seed - # info - log.info(f'[{i}] Computing distances for: {repr(dataset_name)} {repr(pair_mode)} {repr(pairs_per_obs)} {repr(seed)}') - # get the dataset and delete the transform - cached_compute_dataset_pair_dists( - dataset_name=dataset_name, - pair_mode=pair_mode, - pairs_per_obs=pairs_per_obs, - seed=seed, - force=force, - scaled=True - ) - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - ray.init(num_cpus=psutil.cpu_count(logical=False)) - generate_common_cache() - - -# ========================================================================= # -# DONE # -# ========================================================================= # diff --git a/research/e01_visual_overlap/util_compute_traversal_dists.py b/research/e01_visual_overlap/util_compute_traversal_dists.py deleted file mode 100644 index 14f78d71..00000000 --- a/research/e01_visual_overlap/util_compute_traversal_dists.py +++ /dev/null @@ -1,303 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import warnings -from typing import Sequence - -import psutil -import ray - -import logging -import os -from typing import Tuple - -import numpy as np -import torch -from matplotlib import pyplot as plt -from tqdm import tqdm - -import research.util as H -from disent.dataset.data import GroundTruthData -from disent.dataset.util.state_space import StateSpace -from disent.util.strings.fmt import bytes_to_human - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Dataset Stats # -# ========================================================================= # - - -def factor_dist_matrix_shapes(gt_data: GroundTruthData) -> np.ndarray: - # shape: (f_idx, num_factors + 1) - return np.array([factor_dist_matrix_shape(gt_data=gt_data, f_idx=f_idx) for f_idx in range(gt_data.num_factors)]) - - -def factor_dist_matrix_shape(gt_data: GroundTruthData, f_idx: int) -> Tuple[int, ...]: - # using triangular matrices complicates everything - # (np.prod(self._gt_data.factor_sizes) * self._gt_data.factor_sizes[i]) # symmetric, including diagonal in distance matrix - # (np.prod(self._gt_data.factor_sizes) * (self._gt_data.factor_sizes[i] - 1)) // 2 # upper triangular matrix excluding diagonal - # (np.prod(self._gt_data.factor_sizes) * (self._gt_data.factor_sizes[i] + 1)) // 2 # upper triangular matrix including diagonal - return (*np.delete(gt_data.factor_sizes, f_idx), gt_data.factor_sizes[f_idx], gt_data.factor_sizes[f_idx]) - - -def print_dist_matrix_stats(gt_data: GroundTruthData): - # assuming storage as f32 - num_pairs = factor_dist_matrix_shapes(gt_data).prod(axis=1).sum(axis=0) - pre_compute_bytes = num_pairs * (32 // 8) - pairwise_compute_bytes = num_pairs * (32 // 8) * np.prod(gt_data.x_shape) * 2 - traversal_compute_bytes = np.prod(gt_data.x_shape) * np.prod(gt_data.factor_sizes) * gt_data.num_factors - # string - print( - f'{f"{gt_data.name}:":12s} ' - f'{num_pairs:10d} (pairs) ' - f'{bytes_to_human(pre_compute_bytes):>22s} (pre-comp. f32) ' - f'{"x".join(str(i) for i in gt_data.img_shape):>11s} (obs. size)' - f'{bytes_to_human(pairwise_compute_bytes):>22s} (comp. f32) ' - f'{bytes_to_human(traversal_compute_bytes):>22s} (opt. 
f32)' - ) - - -# ========================================================================= # -# Dataset Compute # -# ========================================================================= # - - -def _iter_batch_ranges(total, batch_size): - assert total >= 0 - assert batch_size > 0 - for i in range(0, total, batch_size): - yield range(i, min(i + batch_size, total)) - - -def _check_gt_data(gt_data: GroundTruthData): - obs = gt_data[0] - # checks - assert isinstance(obs, torch.Tensor) - assert obs.dtype == torch.float32 - - -@ray.remote -def _compute_dists( - idxs: Sequence[int], - # thread data - f_states: StateSpace, - f_idx: int, - gt_data: GroundTruthData, - masked: bool, - a_idxs: np.ndarray, - b_idxs: np.ndarray, -): - results = [] - for idx in idxs: - # translate traversal position to dataset position - base_pos = f_states.idx_to_pos(int(idx)) - base_factors = np.insert(base_pos, f_idx, 0) - # load traversal: (f_size, H*W*C) - traversal = [gt_data[i].flatten().numpy() for i in gt_data.iter_traversal_indices(f_idx=f_idx, base_factors=base_factors)] - traversal = np.stack(traversal, axis=0) - # compute distances - if masked: - B, NUM = traversal.shape - # compute mask - mask = (traversal[0] != traversal[1]) - for item in traversal[2:]: - mask |= (traversal[0] != item) - traversal = traversal[:, mask] - # compute distances - dists = np.sum((traversal[a_idxs] - traversal[b_idxs]) ** 2, axis=1, dtype='float32') / NUM # might need to be float64 - else: - dists = np.mean((traversal[a_idxs] - traversal[b_idxs]) ** 2, axis=1, dtype='float32') - # return data - results.append((base_pos, dists)) - return results - - -def get_as_completed(obj_ids): - # https://github.com/ray-project/ray/issues/5554 - while obj_ids: - done, obj_ids = ray.wait(obj_ids) - yield ray.get(done[0]) - - -@torch.no_grad() -def compute_factor_dist_matrices( - gt_data: GroundTruthData, - f_idx: int, - masked: bool = True, - traversals_per_batch: int = 64, -): - if not ray.is_initialized(): - warnings.warn(f'Ray has not yet been initialized, consider calling `ray.init(...)` and specifying the CPU requirements.') - _check_gt_data(gt_data) - # load data - f_states = StateSpace(factor_sizes=np.delete(gt_data.factor_sizes, f_idx)) - a_idxs, b_idxs = H.pair_indices_combinations(gt_data.factor_sizes[f_idx]) - total = len(f_states) - # move to shared memory - ID_f_states = ray.put(f_states) - ID_gt_data = ray.put(gt_data) - ID_a_idxs = ray.put(a_idxs) - ID_b_idxs = ray.put(b_idxs) - # results - f_dist_matrices = np.zeros(factor_dist_matrix_shape(gt_data=gt_data, f_idx=f_idx), dtype='float32') - # generate futures - futures = [ - _compute_dists.remote( - idxs=sub_range, - f_idx=f_idx, - masked=masked, - f_states=ID_f_states, - gt_data=ID_gt_data, - a_idxs=ID_a_idxs, - b_idxs=ID_b_idxs, - ) - for sub_range in _iter_batch_ranges(total, batch_size=traversals_per_batch) - ] - # apply multithreading to compute traversal distances - with tqdm(total=total, desc=f'{gt_data.name}: {f_idx+1} of {gt_data.num_factors}') as p: - # compute distance matrices - for results in get_as_completed(futures): - for base_pos, dists in results: - f_dist_matrices[(*base_pos, a_idxs, b_idxs)] = dists - f_dist_matrices[(*base_pos, b_idxs, a_idxs)] = dists - p.update(len(results)) - # return distances - return f_dist_matrices - - -def compute_all_factor_dist_matrices( - gt_data: GroundTruthData, - masked: bool = True, - traversals_per_batch: int = 64, -): - """ - ALGORITHM: - for each factor: O(num_factors) - for each traversal: O(prod()) - for element in 
traversal: O(n) - -- compute overlapping mask - -- we use this mask to only transfer and compute over the needed data - -- we transfer the MASKED traversal to the GPU not the pairs - for each pair in the traversal: O(n*(n-1)/2) | O(n**2) - -- compute each unique pairs distance - -- return distances - """ - # for each factor, compute pairwise overlap - all_dist_matrices = [] - for f_idx in range(gt_data.num_factors): - f_dist_matrices = compute_factor_dist_matrices( - gt_data=gt_data, - f_idx=f_idx, - masked=masked, - traversals_per_batch=traversals_per_batch, - ) - all_dist_matrices.append(f_dist_matrices) - return all_dist_matrices - - -# TODO: replace this with cachier maybe? -def cached_compute_all_factor_dist_matrices( - dataset_name: str = 'smallnorb', - masked: bool = False, - traversals_per_batch: int = 64, - # cache settings - cache_dir: str = 'data/cache', - force: bool = False, - # normalize - normalize_mode: str = 'all', -): - import os - from disent.util.inout.files import AtomicSaveFile - # load data - gt_data = H.make_data(dataset_name, transform_mode='float32') - # check cache - name = f'dist-matrices_{dataset_name}_masked.npz' if masked else f'dist-matrices_{dataset_name}_full.npz' - cache_path = os.path.abspath(os.path.join(cache_dir, name)) - # generate if it does not exist - if force or not os.path.exists(cache_path): - log.info(f'generating cached distances for: {dataset_name} to: {cache_path}') - # generate & save - with AtomicSaveFile(file=cache_path, overwrite=force) as path: - all_dist_matrices = compute_all_factor_dist_matrices(gt_data, masked=masked, traversals_per_batch=traversals_per_batch) - np.savez(path, **{f_name: f_dists for f_name, f_dists in zip(gt_data.factor_names, all_dist_matrices)}) - # load data - log.info(f'loading cached distances for: {dataset_name} from: {cache_path}') - data = np.load(cache_path) - dist_mats = [data[f_name] for f_name in gt_data.factor_names] - # normalize the max distance to 1.0 - if (normalize_mode == 'none') or (normalize_mode is None): - pass - elif normalize_mode == 'all': - M = np.max([np.max(v) for v in dist_mats]) - dist_mats = [v / M for v in dist_mats] - log.info(f'normalized max over all distances: {M} to 1.0') - elif normalize_mode == 'each': - Ms = [v.max() for v in dist_mats] - dist_mats = [v / M for v, M in zip(dist_mats, Ms)] - log.info(f'normalized max over each factor distance: {Ms} to 1.0') - else: - raise KeyError(f'invalid normalize mode: {repr(normalize_mode)}') - - # done! - return dist_mats - - -# ========================================================================= # -# TEST! 
# -# ========================================================================= # - - -def generate_common_cache(): - for name in ['cars3d', 'smallnorb', 'shapes3d', 'dsprites', 'xysquares']: - # get the dataset and delete the transform - gt_data = H.make_data(name, transform_mode='float32') - print_dist_matrix_stats(gt_data) - f_dist_matrices = cached_compute_all_factor_dist_matrices( - dataset_name=name, - force=True, - masked=True, - traversals_per_batch=32, - ) - # plot distance matrices - H.plt_subplots_imshow( - grid=[[d.reshape([-1, *d.shape[-2:]]).mean(axis=0) for d in f_dist_matrices]], - subplot_padding=0.5, - figsize=(20, 10), - ) - plt.show() - - -def _test_masked_equals_unmasked(): - for name in ['cars3d', 'smallnorb', 'shapes3d', 'dsprites', 'xysquares']: - dists_a = compute_all_factor_dist_matrices(gt_data=H.make_data(name, transform_mode='float32'), masked=True, traversals_per_batch=32) - dists_b = compute_all_factor_dist_matrices(gt_data=H.make_data(name, transform_mode='float32'), masked=False, traversals_per_batch=32) - for a, b in zip(dists_a, dists_b): - assert np.allclose(a, b) - - -if __name__ == '__main__': - ray.init(num_cpus=min(os.cpu_count(), 32)) - generate_common_cache() diff --git a/research/e02_naive_triplet/submit_01_triplet_hparam_sweep.sh b/research/e02_naive_triplet/submit_01_triplet_hparam_sweep.sh deleted file mode 100644 index e6b2f644..00000000 --- a/research/e02_naive_triplet/submit_01_triplet_hparam_sweep.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-02__naive-triplet-hparams" -export PARTITION="batch" -export PARALLELISM=24 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -# general sweep of hyper parameters for triplet -# 1 * (3*3*3*2*3 = 162) = 162 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_tvae_params' \ - \ - run_length=long \ - metrics=all \ - \ - framework=tvae \ - settings.framework.beta=0.0316,0.01,0.1 \ - \ - framework.cfg.triplet_margin_max=0.1,1.0,10.0 \ - framework.cfg.triplet_scale=0.1,1.0,0.01 \ - framework.cfg.triplet_p=1,2 \ - \ - dataset=xysquares,cars3d,smallnorb \ - sampling=gt_dist__manhat - -# check sampling strategy -# 2 * (4 * 5 = 20) = 40 -echo PARAMS NOT SET FROM PREVIOUS SWEEP -exit 1 - -# TODO: set the parameters -submit_sweep \ - +DUMMY.repeat=1,2 \ - +EXTRA.tags='sweep_tvae_sampling' \ - \ - run_length=long \ - metrics=all \ - \ - framework=tvae \ - settings.framework.beta=??? \ - \ - framework.cfg.triplet_margin_max=??? \ - framework.cfg.triplet_scale=??? \ - framework.cfg.triplet_p=??? 
\
-    \
-    dataset=xysquares,cars3d,shapes3d,dsprites,smallnorb \
-    sampling=gt_dist__manhat_scaled,gt_dist__manhat,gt__dist_combined,gt_dist__factors
diff --git a/research/e02_naive_triplet/submit_02_check_vae_equivalence.sh b/research/e02_naive_triplet/submit_02_check_vae_equivalence.sh
deleted file mode 100644
index a07c4783..00000000
--- a/research/e02_naive_triplet/submit_02_check_vae_equivalence.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export USERNAME="n_michlo"
-export PROJECT="final-02__naive-triplet-equivalence"
-export PARTITION="batch"
-export PARALLELISM=24
-
-# source the helper file
-source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-# make sure the tvae is actually working
-# like a vae when the triplet loss is disabled
-# 1 * (4=4) = 4
-submit_sweep \
-    +DUMMY.repeat=1,2 \
-    +EXTRA.tags='check_equivalence' \
-    \
-    run_length=medium \
-    metrics=all \
-    \
-    framework=tvae \
-    framework.cfg.triplet_scale=0.0 \
-    settings.framework.beta=0.0316 \
-    \
-    dataset=xysquares \
-    sampling=gt_dist__manhat_scaled,gt_dist__manhat,gt__dist_combined,gt_dist__factors
-
-# check how sampling affects beta and adavae
-# 2 * (2*3=6) = 12
-submit_sweep \
-    +DUMMY.repeat=1,2 \
-    +EXTRA.tags='check_vae_sampling' \
-    \
-    run_length=medium \
-    metrics=all \
-    \
-    framework=betavae,adavae \
-    settings.framework.beta=0.0316 \
-    \
-    dataset=xysquares \
-    sampling=gt_dist__manhat_scaled,gt_dist__manhat,gt__dist_combined,gt_dist__factors
diff --git a/research/e03_axis_triplet/submit_01.sh b/research/e03_axis_triplet/submit_01.sh
deleted file mode 100644
index e86b94ea..00000000
--- a/research/e03_axis_triplet/submit_01.sh
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export USERNAME="n_michlo"
-export PROJECT="final-03__axis-triplet-3.0"
-export PARTITION="batch"
-export PARALLELISM=24
-
-# source the helper file
-source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cudaless_nodes "$PARTITION" 43200 "C-disent" # 12 hours
-
-# TODO: update this script
-echo UPDATE THIS SCRIPT
-exit 1
-
-# SHORT RUNS:
-# - test for best ada loss types
-# 1 * (2*4*2*8=128) = 128
-submit_sweep \
-    +DUMMY.repeat=1 \
-    \
-    framework=X--adatvae \
-    dataset=xysquares \
-    run_length=short \
-    \
-    framework.cfg.triplet_margin_max=1.0 \
-    framework.cfg.triplet_scale=0.1 \
-    framework.cfg.triplet_p=1 \
-    sampling=gt_dist_manhat \
-    \
-    model.z_size=25,9 \
-    \
-    framework.cfg.thresh_ratio=0.5 \
-    framework.cfg.ada_triplet_ratio=1.0 \
-    schedule=adavae_thresh,adavae_all,adavae_ratio,none \
-    framework.cfg.ada_triplet_sample=TRUE,FALSE \
-    framework.cfg.ada_triplet_loss=triplet,triplet_soft_ave,triplet_soft_neg_ave,triplet_all_soft_ave,triplet_hard_ave,triplet_hard_neg_ave,triplet_hard_neg_ave_pull,triplet_all_hard_ave
-
-# ADA TRIPLET LOSS MODES (short runs):
-# - generally don't use sampling, except for: triplet_hard_neg_ave_pull
-# - soft averages don't work if scheduling thresh or ratio separately, need to do both at the same time
-# - hard averages perform well initially, but performance decays more toward the end of schedules
-# =======================
-# [X] triplet
-#
-# [-] triplet_soft_ave [NOTE: OK, but just worse than, triplet_all_soft_ave]
-# triplet_soft_neg_ave [NOTE: better disentanglement than triplet_all_soft_ave, but worse axis align]
-# triplet_all_soft_ave
-#
-# triplet_hard_neg_ave
-# triplet_hard_neg_ave_pull (weight = 0.1, triplet_hard_neg_ave_pull_soft)
-# [X] triplet_hard_ave
-# [X] triplet_hard_neg_ave_pull (weight = 1.0)
-# [X] triplet_all_hard_ave
diff --git a/research/e03_axis_triplet/submit_02.sh b/research/e03_axis_triplet/submit_02.sh
deleted file mode 100644
index 76e4b1dd..00000000
--- a/research/e03_axis_triplet/submit_02.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export USERNAME="n_michlo"
-export PROJECT="final-03__axis-triplet-3.0"
-export PARTITION="batch"
-export PARALLELISM=30
-
-# source the helper file
-source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours
-
-# TODO: update this script
-echo UPDATE THIS SCRIPT
-exit 1
-
-# MED RUNS:
-# - test for best hparams for all soft ave loss
-# 2 * (2*3*3*3=54) = 108
-submit_sweep \
-    +DUMMY.repeat=1,2 \
-    +EXTRA.tags='med-run+soft-hparams' \
-    \
-    framework=X--adatvae \
-    run_length=medium \
-    model.z_size=25 \
-    \
-    framework.cfg.triplet_margin_max=1.0,5.0 \
-    framework.cfg.triplet_scale=0.1,0.02,0.5 \
-    framework.cfg.triplet_p=1 \
-    sampling=gt_dist_manhat \
-    \
-    framework.cfg.thresh_ratio=0.5 \
-    framework.cfg.ada_triplet_ratio=1.0 \
-    framework.cfg.ada_triplet_soft_scale=0.25,1.0,4.0 \
-    framework.cfg.ada_triplet_sample=FALSE \
-    \
-    schedule=adavae_all,adavae_thresh,adavae_ratio \
-    framework.cfg.ada_triplet_loss=triplet_all_soft_ave \
-    dataset=xysquares
diff --git a/research/e03_axis_triplet/submit_03.sh b/research/e03_axis_triplet/submit_03.sh
deleted file mode 100644
index 4317e923..00000000
--- a/research/e03_axis_triplet/submit_03.sh
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export USERNAME="n_michlo"
-export PROJECT="final-03__axis-triplet-3.0"
-export PARTITION="stampede"
-export PARALLELISM=32
-
-# source the helper file
-source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours
-
-# TODO: update this script
-echo UPDATE THIS SCRIPT
-exit 1
-
-# LONG RUNS:
-# - test best losses & best hparams from test1 on different datasets with long runs
-# + [not tested] triplet_soft_neg_ave
-# + triplet_all_soft_ave
-# + triplet_hard_neg_ave
-# + triplet_hard_neg_ave_pull
-
-# 1 * (2*3*4*4=96) = 96
-#submit_sweep \
-#    +DUMMY.repeat=1 \
-#    +EXTRA.tags='long-run' \
-#    \
-#    framework=X--adatvae \
-#    run_length=long \
-#    model.z_size=25 \
-#    \
-#    framework.cfg.triplet_margin_max=1.0 \
-#    framework.cfg.triplet_scale=0.1 \
-#    framework.cfg.triplet_p=1 \
-#    sampling=gt_dist_manhat,gt_dist_manhat_scaled \
-#    \
-#    framework.cfg.thresh_ratio=0.5 \
-#    framework.cfg.ada_triplet_ratio=1.0 \
-#    framework.cfg.ada_triplet_soft_scale=1.0 \
-#    framework.cfg.ada_triplet_sample=FALSE \
-#    \
-#    schedule=adavae_all,adavae_thresh,adavae_ratio \
-#    framework.cfg.ada_triplet_loss=triplet,triplet_all_soft_ave,triplet_hard_neg_ave,triplet_hard_neg_ave_pull \
-#    dataset=xysquares,shapes3d,cars3d,dsprites
-
-# 2*2*3*4*4
-submit_sweep \
-    +DUMMY.repeat=1 \
-    +EXTRA.tags='med-run+datasets+swap-chance+manhat-scaled' \
-    \
-    framework=X--adatvae \
-    run_length=medium \
-    model.z_size=25 \
-    \
-    sampling=gt_dist_manhat_scaled,gt_dist_manhat \
-    schedule=adavae_all,adavae_thresh,adavae_ratio \
-    sampling.triplet_swap_chance=0,0.1 \
-    \
-    framework.cfg.triplet_margin_max=1.0 \
-    framework.cfg.triplet_scale=0.1 \
-    framework.cfg.triplet_p=1 \
-    \
-    framework.cfg.thresh_ratio=0.5 \
-    framework.cfg.ada_triplet_ratio=1.0 \
-    framework.cfg.ada_triplet_soft_scale=1.0 \
-    framework.cfg.ada_triplet_sample=FALSE \
-    \
-    framework.cfg.ada_triplet_loss=triplet,triplet_all_soft_ave,triplet_hard_neg_ave,triplet_hard_neg_ave_pull \
-    dataset=xysquares,shapes3d,cars3d,dsprites
diff --git a/research/e03_axis_triplet/submit_04.sh b/research/e03_axis_triplet/submit_04.sh
deleted file mode 100644
index b44ae30f..00000000
--- a/research/e03_axis_triplet/submit_04.sh
+++ /dev/null
@@ -1,119 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export USERNAME="n_michlo"
-export PROJECT="final-03__axis-triplet-4.0"
-export PARTITION="stampede"
-export PARALLELISM=24
-
-# source the helper file
-source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours
-
-# TODO: update this script
-echo UPDATE THIS SCRIPT
-exit 1
-
-# RESULT:
-# - BAD: ada_thresh_mode=symmetric_kl, rather use "dist"
-# - BAD: framework.cfg.adaave_decode_orig=FALSE, rather use TRUE
-# - adat_share_ave_mode depends on other settings, but usually doesn't matter
-# - adaave_augment_orig depends on other settings, but usually doesn't matter
-# - GOOD: adat_triplet_loss=triplet_hard_neg_ave
-# - NOTE: schedule=adavae_up_ratio usually converges sooner
-# - NOTE: schedule=adavae_up_all usually converges later (makes sense because it's a doubling effect a ^ 2)
-# - NOTE: schedule=adavae_up_thresh usually is worse at converging
-
-
-# 3*2*4*2*2*2 == 192
-submit_sweep \
-    +DUMMY.repeat=1 \
-    +EXTRA.tags='short-run__ada-best-loss-combo' \
-    \
-    framework=X--adaavetvae \
-    run_length=short \
-    model.z_size=25 \
-    \
-    schedule=adavae_up_all,adavae_up_ratio,adavae_up_thresh \
-    sampling=gt_dist_manhat \
-    sampling.triplet_swap_chance=0 \
-    dataset=xysquares \
-    \
-    framework.cfg.triplet_loss=triplet \
-    framework.cfg.triplet_margin_min=0.001 \
-    framework.cfg.triplet_margin_max=1 \
-    framework.cfg.triplet_scale=0.1 \
-    framework.cfg.triplet_p=1 \
-    \
-    framework.cfg.detach=FALSE \
-    framework.cfg.detach_decoder=FALSE \
-
framework.cfg.detach_no_kl=FALSE \ - framework.cfg.detach_std=NULL \ - \ - framework.module.ada_average_mode=gvae \ - framework.module.ada_thresh_mode=symmetric_kl,dist \ - framework.module.ada_thresh_ratio=0.5 \ - \ - framework.module.adat_triplet_loss=triplet,triplet_soft_ave_all,triplet_hard_neg_ave,triplet_hard_ave_all \ - framework.module.adat_triplet_ratio=1.0 \ - framework.module.adat_triplet_soft_scale=1.0 \ - framework.module.adat_triplet_pull_weight=0.1 \ - \ - framework.module.adat_share_mask_mode=posterior \ - framework.module.adat_share_ave_mode=all,neg \ - \ - framework.module.adaave_augment_orig=TRUE,FALSE \ - framework.module.adaave_decode_orig=TRUE,FALSE - -# TRY THESE TOO: -# framework.module.adat_share_ave_mode=all,neg,pos,pos_neg \ -# framework.module.adat_share_mask_mode=posterior,sample,sample_each \ -# framework.module.adat_triplet_loss=triplet,triplet_soft_ave_all,triplet_hard_neg_ave,triplet_hard_neg_ave_pull,triplet_hard_ave_all \ - -# # 3*2*8*2*3*2*2 -#submit_sweep \ -# +DUMMY.repeat=1 \ -# +EXTRA.tags='short-run__ada-best-loss-combo' \ -# \ -# framework=X--adaavetvae \ -# run_length=short \ -# model.z_size=25 \ -# \ -# schedule=adavae_all,adavae_thresh,adavae_ratio \ -# sampling=gt_dist_manhat \ -# sampling.triplet_swap_chance=0 \ -# dataset=xysquares \ -# \ -# triplet_loss=triplet \ -# triplet_margin_min=0.001 \ -# triplet_margin_max=1 \ -# triplet_scale=0.1 \ -# triplet_p=1 \ -# \ -# detach=FALSE \ -# disable_decoder=FALSE \ -# detach_no_kl=FALSE \ -# detach_std=NULL \ -# \ -# ada_average_mode=gvae \ -# ada_thresh_mode=symmetric_kl,dist \ -# ada_thresh_ratio=0.5 \ -# \ -# adat_triplet_loss=triplet,triplet_soft_ave_neg,triplet_soft_ave_p_n,triplet_soft_ave_all,triplet_hard_ave,triplet_hard_neg_ave,triplet_hard_neg_ave_pull,triplet_hard_ave_all \ -# adat_triplet_ratio=1.0 \ -# adat_triplet_soft_scale=1.0 \ -# adat_triplet_pull_weight=0.1 \ -# \ -# adat_share_mask_mode=posterior,dist \ -# adat_share_ave_mode=all,pos_neg,pos,neg \ -# \ -# adaave_augment_orig=TRUE,FALSE \ -# adaave_decode_orig=TRUE,FALSE diff --git a/research/e03_axis_triplet/submit_05.sh b/research/e03_axis_triplet/submit_05.sh deleted file mode 100644 index 5ea5025f..00000000 --- a/research/e03_axis_triplet/submit_05.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-03__axis-triplet-5.0" -export PARTITION="stampede" -export PARALLELISM=16 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# 1 * (3*6*5) == 90 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='ada-best-pull-weight' \ - \ - framework=X--adanegtvae \ - run_length=short,medium,long \ - model.z_size=25 \ - \ - schedule=adavae_down_all,adavae_up_all,adavae_down_ratio,adavae_up_ratio,adavae_down_thresh,adavae_up_thresh \ - sampling=gt_dist_manhat \ - sampling.triplet_swap_chance=0 \ - dataset=xysquares \ - \ - framework.cfg.triplet_loss=triplet \ - framework.cfg.triplet_margin_min=0.001 \ - framework.cfg.triplet_margin_max=1 \ - framework.cfg.triplet_scale=0.1 \ - 
framework.cfg.triplet_p=1 \ - \ - framework.cfg.detach=FALSE \ - framework.cfg.detach_decoder=FALSE \ - framework.cfg.detach_no_kl=FALSE \ - framework.cfg.detach_std=NULL \ - \ - framework.cfg.ada_average_mode=gvae \ - framework.cfg.ada_thresh_mode=dist \ - framework.cfg.ada_thresh_ratio=0.5 \ - \ - framework.cfg.adat_triplet_ratio=1.0 \ - framework.cfg.adat_triplet_pull_weight=-1.0,-0.1,0.0,0.1,1.0 \ - \ - framework.cfg.adat_share_mask_mode=posterior diff --git a/research/e04_data_overlap_triplet/submit_01.sh b/research/e04_data_overlap_triplet/submit_01.sh deleted file mode 100644 index 2c4f2630..00000000 --- a/research/e04_data_overlap_triplet/submit_01.sh +++ /dev/null @@ -1,62 +0,0 @@ -##!/bin/bash -# -## ========================================================================= # -## Settings # -## ========================================================================= # -# -#export USERNAME="n_michlo" -#export PROJECT="final-04__data-overlap-triplet" -#export PARTITION="stampede" -#export PARALLELISM=32 -# -## source the helper file -#source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" -# -## ========================================================================= # -## Experiment # -## ========================================================================= # -# -#clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours -# -## 1 * (3*2*2*5*2) == 120 -#submit_sweep \ -# +DUMMY.repeat=1 \ -# +EXTRA.tags='med-best' \ -# \ -# framework=X--dotvae_aug \ -# run_length=medium \ -# model.z_size=25 \ -# \ -# schedule=adavae_up_all,adavae_up_ratio,none \ -# sampling=gt_dist_manhat \ -# sampling.triplet_swap_chance=0 \ -# dataset=xysquares \ -# \ -# framework.cfg.triplet_loss=triplet \ -# framework.cfg.triplet_margin_min=0.001 \ -# framework.cfg.triplet_margin_max=1 \ -# framework.cfg.triplet_scale=0.1,0.01 \ -# framework.cfg.triplet_p=1 \ -# \ -# framework.cfg.detach=FALSE \ -# framework.cfg.disable_decoder=FALSE \ -# framework.cfg.detach_no_kl=FALSE \ -# framework.cfg.detach_std=NULL \ -# \ -# framework.cfg.ada_average_mode=gvae \ -# framework.cfg.ada_thresh_mode=dist \ -# framework.cfg.ada_thresh_ratio=0.5 \ -# \ -# framework.cfg.adat_triplet_share_scale=0.95 \ -# \ -# framework.cfg.adat_share_mask_mode=posterior \ -# \ -# framework.cfg.overlap_num=4096 \ -# framework.cfg.overlap_mine_ratio=0.05,0.1 \ -# framework.cfg.overlap_mine_triplet_mode=none,hard_neg,semi_hard_neg,hard_pos,easy_pos \ -# \ -# framework.cfg.overlap_augment_mode='augment' \ -# framework.cfg.overlap_augment.p=1.0 \ -# framework.cfg.overlap_augment.radius=[61,61],[0,61] \ -# framework.cfg.overlap_augment.random_mode='batch' \ -# framework.cfg.overlap_augment.random_same_xy=TRUE diff --git a/research/e04_data_overlap_triplet/submit_02.sh b/research/e04_data_overlap_triplet/submit_02.sh deleted file mode 100644 index 81865f40..00000000 --- a/research/e04_data_overlap_triplet/submit_02.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-04__data-overlap-triplet" -export PARTITION="batch" -export PARALLELISM=16 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - 
-clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# 1 * (2*8*4) == 64 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='best-augment-strength__alt' \ - \ - framework=X--dotvae_aug \ - run_length=short \ - model=conv64alt \ - model.z_size=25 \ - \ - schedule=adavae_up_ratio_full,adavae_up_all_full \ - sampling=gt_dist_manhat \ - sampling.triplet_swap_chance=0 \ - dataset=xysquares \ - \ - framework.cfg.triplet_loss=triplet \ - framework.cfg.triplet_margin_min=0.001 \ - framework.cfg.triplet_margin_max=1 \ - framework.cfg.triplet_scale=0.1 \ - framework.cfg.triplet_p=1 \ - \ - framework.cfg.detach=FALSE \ - framework.cfg.detach_decoder=FALSE \ - framework.cfg.detach_no_kl=FALSE \ - framework.cfg.detach_std=NULL \ - \ - framework.cfg.ada_average_mode=gvae \ - framework.cfg.ada_thresh_mode=dist \ - framework.cfg.ada_thresh_ratio=0.5 \ - \ - framework.cfg.adat_triplet_share_scale=1.0 \ - \ - framework.cfg.adat_share_mask_mode=posterior \ - \ - framework.cfg.overlap_augment_mode='augment' \ - framework.cfg.overlap_augment.kernel=xy1_r47,xy8_r47,box_r47,gau_r47 \ - \ - framework.cfg.overlap_num=4096 \ - framework.module.overlap_mine_ratio=0.1 \ - framework.module.overlap_mine_triplet_mode=none,hard_neg,semi_hard_neg,hard_pos,easy_pos,ran:hard_neg+hard_pos,ran:hard_neg+easy_pos,ran:hard_pos+easy_pos - - # framework.module.overlap_augment.kernel=xy1_r47,xy8_r47,box_r47,gau_r47,box_r15,box_r31,box_r63,gau_r15,gau_r31,gau_r63 diff --git a/research/e04_data_overlap_triplet/submit_03_test_softada_vs_ada.sh b/research/e04_data_overlap_triplet/submit_03_test_softada_vs_ada.sh deleted file mode 100644 index 494cc903..00000000 --- a/research/e04_data_overlap_triplet/submit_03_test_softada_vs_ada.sh +++ /dev/null @@ -1,65 +0,0 @@ -#!/bin/bash - -# -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="test-hard-vs-soft-ada" -export PARTITION="stampede" -export PARALLELISM=16 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# 3 * (3 * 1) = 9 -submit_sweep \ - +DUMMY.repeat=1,2,3 \ - +EXTRA.tags='sweep_02' \ - \ - run_length=medium \ - metrics=all \ - \ - framework.beta=1 \ - framework=adavae_os,adagvae_minimal_os,X--softadagvae_minimal_os \ - model.z_size=25 \ - \ - dataset=shapes3d \ - \ - hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97"' # we don't want to sweep over these diff --git a/research/e05_disentangle_kernel/run_01_sort_loss.py b/research/e05_disentangle_kernel/run_01_sort_loss.py deleted file mode 100644 index 710b2f3b..00000000 --- a/research/e05_disentangle_kernel/run_01_sort_loss.py +++ /dev/null @@ -1,80 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader - -import research.util as H -from disent.nn.loss.softsort import multi_spearman_rank_loss -from disent.nn.loss.softsort import torch_soft_rank - - -# ========================================================================= # -# tests # -# ========================================================================= # - - -def run_differentiable_sorting_loss(dataset='dsprites', loss_mode='spearman', optimizer='adam', lr=1e-2): - """ - test that the differentiable sorting works over a batch of images. 
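- ie. a random noise batch `x` is optimised so that its (soft) pixel ranks match those of a real target batch `y`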
- """ - - dataset = H.make_dataset(dataset) - dataloader = DataLoader(dataset=dataset, batch_size=256, pin_memory=True, shuffle=True) - - y = H.get_single_batch(dataloader) - # y += torch.randn_like(y) * 0.001 # prevent nan errors - x = torch.randn_like(y, requires_grad=True) - - optimizer = H.make_optimizer(x, name=optimizer, lr=lr) - - for i in range(1001): - if loss_mode == 'spearman': - loss = multi_spearman_rank_loss(x, y, dims=(2, 3), nan_to_num=True) - elif loss_mode == 'mse_rank': - loss = 0. - loss += F.mse_loss(torch_soft_rank(x, dims=(-3, -1)), torch_soft_rank(y, dims=(-3, -1)), reduction='mean') - loss += F.mse_loss(torch_soft_rank(x, dims=(-3, -2)), torch_soft_rank(y, dims=(-3, -2)), reduction='mean') - elif loss_mode == 'mse': - loss += F.mse_loss(x, y, reduction='mean') - else: - raise KeyError(f'invalid loss mode: {repr(loss_mode)}') - - # update variables - H.step_optimizer(optimizer, loss) - if i % 250 == 0: - H.plt_imshow(H.to_img(x[0]), show=True) - - # compute loss - print(i, float(loss)) - - -# ========================================================================= # -# MAIN # -# ========================================================================= # - - -if __name__ == '__main__': - run_differentiable_sorting_loss() diff --git a/research/e05_disentangle_kernel/run_02_check_aug_gt_dists.py b/research/e05_disentangle_kernel/run_02_check_aug_gt_dists.py deleted file mode 100644 index 9dd69e0e..00000000 --- a/research/e05_disentangle_kernel/run_02_check_aug_gt_dists.py +++ /dev/null @@ -1,168 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -import numpy as np -import torch -import torch.nn.functional as F -from matplotlib import pyplot as plt -from tqdm import tqdm - -import research.util as H -from disent.nn.functional import torch_box_kernel_2d -from disent.nn.functional import torch_conv2d_channel_wise_fft -from disent.nn.functional import torch_gaussian_kernel_2d - - -# ========================================================================= # -# distance function # -# ========================================================================= # - - -def spearman_rank_dist( - pred: torch.Tensor, - targ: torch.Tensor, - reduction='mean', - nan_to_num=False, -): - # add missing dim - if pred.ndim == 1: - pred, targ = pred.reshape(1, -1), targ.reshape(1, -1) - assert pred.shape == targ.shape - assert pred.ndim == 2 - # sort the last dimension of the 2D tensors - pred = torch.argsort(pred).to(torch.float32) - targ = torch.argsort(targ).to(torch.float32) - # compute individual losses - # TODO: this can result in nan values, what to do then? - pred = pred - pred.mean(dim=-1, keepdim=True) - pred = pred / pred.norm(dim=-1, keepdim=True) - targ = targ - targ.mean(dim=-1, keepdim=True) - targ = targ / targ.norm(dim=-1, keepdim=True) - # replace nan values - if nan_to_num: - pred = torch.nan_to_num(pred, nan=0.0) - targ = torch.nan_to_num(targ, nan=0.0) - # compute the final loss - loss = (pred * targ).sum(dim=-1) - # reduce the loss - if reduction == 'mean': - return loss.mean() - elif reduction == 'none': - return loss - else: - raise KeyError(f'Invalid reduction mode: {repr(reduction)}') - - -def check_xy_squares_dists(kernel='box', repeats=100, samples=256, pairwise_samples=256, kernel_radius=32, show_prog=True): - if kernel == 'box': - kernel = torch_box_kernel_2d(radius=kernel_radius)[None, ...] - elif kernel == 'max_box': - crange = torch.abs(torch.arange(kernel_radius * 2 + 1) - kernel_radius) - y, x = torch.meshgrid(crange, crange) - d = torch.maximum(x, y) + 1 - d = d.max() - d - kernel = (d.to(torch.float32) / d.sum())[None, None, ...] - elif kernel == 'min_box': - crange = torch.abs(torch.arange(kernel_radius * 2 + 1) - kernel_radius) - y, x = torch.meshgrid(crange, crange) - d = torch.minimum(x, y) + 1 - d = d.max() - d - kernel = (d.to(torch.float32) / d.sum())[None, None, ...] - elif kernel == 'manhat_box': - crange = torch.abs(torch.arange(kernel_radius * 2 + 1) - kernel_radius) - y, x = torch.meshgrid(crange, crange) - d = (y + x) + 1 - d = d.max() - d - kernel = (d.to(torch.float32) / d.sum())[None, None, ...] - elif kernel == 'gaussian': - kernel = torch_gaussian_kernel_2d(sigma=kernel_radius / 4.0, truncate=4.0)[None, None, ...] 
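- # (sigma = radius/4 with truncate=4.0 should give the gaussian a support half-width of ~radius, matching the other kernels)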
- else: - raise KeyError(f'invalid kernel mode: {repr(kernel)}') - - # make dataset - dataset = H.make_dataset('xysquares') - - losses = [] - prog = tqdm(range(repeats), postfix={'loss': 0.0}) if show_prog else range(repeats) - - for i in prog: - # get random samples - factors = dataset.sample_factors(samples) - batch = dataset.dataset_batch_from_factors(factors, mode='target') - if torch.cuda.is_available(): - batch = batch.cuda() - kernel = kernel.cuda() - factors = torch.from_numpy(factors).to(dtype=torch.float32, device=batch.device) - - # random pairs - ia, ib = torch.randint(0, len(batch), size=(2, pairwise_samples), device=batch.device) - - # compute factor distances - f_dists = torch.abs(factors[ia] - factors[ib]).sum(dim=-1) - - # compute loss distances - aug_batch = torch_conv2d_channel_wise_fft(batch, kernel) - # TODO: aug - batch or aug - aug - # b_dists = torch.abs(aug_batch[ia] - aug_batch[ib]).sum(dim=(-3, -2, -1)) - b_dists = F.mse_loss(aug_batch[ia], aug_batch[ib], reduction='none').sum(dim=(-3, -2, -1)) - - # compute ranks - # losses.append(float(torch.clamp(torch_mse_rank_loss(b_dists, f_dists), 0, 100))) - # losses.append(float(torch.abs(torch.argsort(f_dists, descending=True) - torch.argsort(b_dists, descending=False)).to(torch.float32).mean())) - losses.append(float(spearman_rank_dist(b_dists, f_dists))) - - if show_prog: - prog.set_postfix({'loss': np.mean(losses)}) - - return np.mean(losses), aug_batch[0] - - -def run_check_all_xy_squares_dists(show=False): - for kernel in [ - 'box', - 'max_box', - 'min_box', - 'manhat_box', - 'gaussian', - ]: - rs = list(range(1, 33, 4)) - ys = [] - for r in rs: - ave_spearman, last_img = check_xy_squares_dists(kernel=kernel, repeats=32, samples=128, pairwise_samples=1024, kernel_radius=r, show_prog=False) - H.plt_imshow(H.to_img(last_img, scale=True), show=show) - ys.append(abs(ave_spearman)) - print(kernel, r, ':', r*2+1, abs(ave_spearman)) - plt.plot(rs, ys, label=kernel) - plt.legend() - plt.show() - - -# ========================================================================= # -# MAIN # -# ========================================================================= # - - -if __name__ == '__main__': - run_check_all_xy_squares_dists() diff --git a/research/e05_disentangle_kernel/run_03_train_disentangle_kernel.py b/research/e05_disentangle_kernel/run_03_train_disentangle_kernel.py deleted file mode 100644 index eb15d491..00000000 --- a/research/e05_disentangle_kernel/run_03_train_disentangle_kernel.py +++ /dev/null @@ -1,297 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import logging -import os -from typing import List -from typing import Optional -from typing import Sequence - -import hydra -import numpy as np -import pytorch_lightning as pl -import torch -import wandb -from omegaconf import OmegaConf -from torch.nn import Parameter -from torch.utils.data import DataLoader - -import disent.util.seeds -import research.util as H -from disent.nn.functional import torch_conv2d_channel_wise_fft -from disent.nn.loss.softsort import spearman_rank_loss -from disent.nn.modules import DisentLightningModule -from disent.nn.modules import DisentModule -from disent.util.lightning.callbacks import BaseCallbackPeriodic -from disent.util.lightning.logger_util import wb_log_metrics -from disent.util.seeds import seed -from disent.util.strings.fmt import make_box_str -from experiment.run import hydra_append_progress_callback -from experiment.run import hydra_get_gpus -from experiment.run import hydra_make_logger -from experiment.util.hydra_utils import make_non_strict - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# EXP # -# ========================================================================= # - - -def disentangle_loss( - batch: torch.Tensor, - factors: torch.Tensor, - num_pairs: int, - f_idxs: Optional[List[int]] = None, - loss_fn: str = 'mse', - mean_dtype=None, -) -> torch.Tensor: - assert len(batch) == len(factors) - assert batch.ndim == 4 - assert factors.ndim == 2 - # random pairs - ia, ib = torch.randint(0, len(batch), size=(2, num_pairs), device=batch.device) - # get pairwise distances - b_dists = H.pairwise_loss(batch[ia], batch[ib], mode=loss_fn, mean_dtype=mean_dtype) # avoid precision errors - # compute factor distances - if f_idxs is not None: - f_dists = torch.abs(factors[ia][:, f_idxs] - factors[ib][:, f_idxs]).sum(dim=-1) - else: - f_dists = torch.abs(factors[ia] - factors[ib]).sum(dim=-1) - # optimise metric - loss = spearman_rank_loss(b_dists, -f_dists) # decreasing overlap should mean increasing factor dist - return loss - - -class DisentangleModule(DisentLightningModule): - - def __init__( - self, - model, - hparams, - disentangle_factor_idxs: Sequence[int] = None - ): - super().__init__() - self.model = model - self.hparams = hparams - self._disentangle_factors = None if (disentangle_factor_idxs is None) else np.array(disentangle_factor_idxs) - - def configure_optimizers(self): - return H.make_optimizer(self, name=self.hparams.optimizer.name, lr=self.hparams.optimizer.lr, weight_decay=self.hparams.optimizer.weight_decay) - - def training_step(self, batch, batch_idx): - (batch,), (factors,) = batch['x_targ'], batch['factors'] - # feed forward batch - aug_batch = self.model(batch) - # compute pairwise distances of factors and batch, and optimize to correspond - loss = disentangle_loss( - batch=aug_batch, - factors=factors, - num_pairs=int(len(batch) * self.hparams.train.pairs_ratio), - f_idxs=self._disentangle_factors, - loss_fn=self.hparams.train.loss, - mean_dtype=torch.float64, - ) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - if hasattr(self.model, 'augment_loss'): - loss_aug = self.model.augment_loss(self) - else: - loss_aug = 
0 - loss += loss_aug - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - self.log('loss', loss) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - return loss - - def forward(self, batch): - return self.model(batch) - - -# ========================================================================= # -# MAIN # -# ========================================================================= # - - -class Kernel(DisentModule): - def __init__(self, radius: int = 33, channels: int = 1, offset: float = 0.0, scale: float = 0.001, train_symmetric_regularise: bool = True, train_norm_regularise: bool = True, train_nonneg_regularise: bool = True): - super().__init__() - assert channels in (1, 3) - kernel = torch.randn(1, channels, 2*radius+1, 2*radius+1, dtype=torch.float32) - kernel = offset + kernel * scale - # normalise - if train_nonneg_regularise: - kernel = torch.abs(kernel) - if train_norm_regularise: - kernel = kernel / kernel.sum(dim=[-1, -2], keepdim=True) - # store - self._kernel = Parameter(kernel) - # regularise options - self._train_symmetric_regularise = train_symmetric_regularise - self._train_norm_regularise = train_norm_regularise - self._train_nonneg_regularise = train_nonneg_regularise - - def forward(self, xs): - return torch_conv2d_channel_wise_fft(xs, self._kernel) - - def make_train_periodic_callback(self, cfg, dataset) -> BaseCallbackPeriodic: - class ImShowCallback(BaseCallbackPeriodic): - def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - # get kernel image - kernel = H.to_img(pl_module.model._kernel[0], scale=True).numpy() - # augment function - def augment_fn(batch): - return H.to_imgs(pl_module.forward(batch.to(pl_module.device)), scale=True) - # get augmented traversals - with torch.no_grad(): - orig_wandb_image, orig_wandb_animation = H.visualize_dataset_traversal(dataset) - augm_wandb_image, augm_wandb_animation = H.visualize_dataset_traversal(dataset, augment_fn=augment_fn, data_mode='input') - # log images to WANDB - wb_log_metrics(trainer.logger, { - 'kernel': wandb.Image(kernel), - 'traversal_img_orig': orig_wandb_image, 'traversal_animation_orig': orig_wandb_animation, - 'traversal_img_augm': augm_wandb_image, 'traversal_animation_augm': augm_wandb_animation, - }) - return ImShowCallback(every_n_steps=cfg.exp.show_every_n_steps, begin_first_step=True) - - def augment_loss(self, framework: DisentLightningModule): - augment_loss = 0 - # symmetric loss - if self._train_symmetric_regularise: - k, kt = self._kernel[0], torch.transpose(self._kernel[0], -1, -2) - loss_symmetric = 0 - loss_symmetric += H.unreduced_loss(torch.flip(k, dims=[-1]), k, mode='mae').mean() - loss_symmetric += H.unreduced_loss(torch.flip(k, dims=[-2]), k, mode='mae').mean() - loss_symmetric += H.unreduced_loss(torch.flip(k, dims=[-1]), kt, mode='mae').mean() - loss_symmetric += H.unreduced_loss(torch.flip(k, dims=[-2]), kt, mode='mae').mean() - # log loss - framework.log('loss_symmetric', loss_symmetric) - # final loss - augment_loss += loss_symmetric - # sum of 1 loss, per channel - if self._train_norm_regularise: - k = self._kernel[0] - # sum over W & H resulting in: (C, W, H) -> (C,) - channel_sums = k.sum(dim=[-1, -2]) - channel_loss = H.unreduced_loss(channel_sums, torch.ones_like(channel_sums), mode='mae') - norm_loss = channel_loss.mean() - # log loss - framework.log('loss_norm', norm_loss) - # final loss - augment_loss += norm_loss - # no negatives regulariser - if self._train_nonneg_regularise: - k = self._kernel[0] - nonneg_loss = torch.abs(k[k < 0].sum()) - # log loss - 
framework.log('loss_non_negative', nonneg_loss) - # regularise negatives - augment_loss += nonneg_loss - # return! - return augment_loss - - -# ========================================================================= # -# Run Hydra # -# ========================================================================= # - - -ROOT_DIR = os.path.abspath(__file__ + '/../../../..') - - -@hydra.main(config_path=os.path.join(ROOT_DIR, 'experiment/config'), config_name="config_adversarial_kernel") -def run_disentangle_dataset_kernel(cfg): - cfg = make_non_strict(cfg) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # TODO: some of this code is duplicated between this and the main experiment run.py - # check CUDA setting - cfg.trainer.setdefault('cuda', 'try_cuda') - gpus = hydra_get_gpus(cfg) - # CREATE LOGGER - logger = hydra_make_logger(cfg) - # TRAINER CALLBACKS - callbacks = [] - hydra_append_progress_callback(callbacks, cfg) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - seed(cfg.default_settings.job.seed) # NOTE: assumed config key, seed() expects an integer seed value - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # initialise dataset and get factor names to disentangle - dataset = H.make_dataset(cfg.data.name, factors=True, data_root=cfg.default_settings.storage.data_root) - disentangle_factor_idxs = dataset.gt_data.normalise_factor_idxs(cfg.kernel.disentangle_factors) - cfg.kernel.disentangle_factors = tuple(dataset.gt_data.factor_names[i] for i in disentangle_factor_idxs) - log.info(f'Dataset has ground-truth factors: {dataset.gt_data.factor_names}') - log.info(f'Chosen ground-truth factors are: {tuple(cfg.kernel.disentangle_factors)}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # print everything - log.info('Final Config' + make_box_str(OmegaConf.to_yaml(cfg))) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - dataloader = DataLoader( - dataset, - batch_sampler=H.StochasticBatchSampler(dataset, batch_size=cfg.dataset.batch_size), - num_workers=cfg.dataset.num_workers, - pin_memory=cfg.dataset.pin_memory, - ) - model = Kernel(radius=cfg.kernel.radius, channels=cfg.kernel.channels, offset=0.002, scale=0.01, train_symmetric_regularise=cfg.kernel.regularize_symmetric, train_norm_regularise=cfg.kernel.regularize_norm, train_nonneg_regularise=cfg.kernel.regularize_nonneg) - callbacks.append(model.make_train_periodic_callback(cfg, dataset=dataset)) - framework = DisentangleModule(model, cfg, disentangle_factor_idxs=disentangle_factor_idxs) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - if framework.logger: - framework.logger.log_hyperparams(framework.hparams) - # train - trainer = pl.Trainer( - log_every_n_steps=cfg.log.setdefault('log_every_n_steps', 50), - flush_logs_every_n_steps=cfg.log.setdefault('flush_logs_every_n_steps', 100), - logger=logger, - callbacks=callbacks, - gpus=1 if gpus else 0, - max_epochs=cfg.trainer.setdefault('epochs', None), - max_steps=cfg.trainer.setdefault('steps', 10000), - progress_bar_refresh_rate=0, # ptl 0.9 - terminate_on_nan=True, # we do this here so we don't run the final metrics - checkpoint_callback=False, - ) - trainer.fit(framework, dataloader) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # save kernel - if cfg.exp.rel_save_dir is not None: - assert not os.path.isabs(cfg.exp.rel_save_dir), f'rel_save_dir must be relative: {repr(cfg.exp.rel_save_dir)}' - save_dir = os.path.join(ROOT_DIR, cfg.exp.rel_save_dir) - assert os.path.isabs(save_dir), f'save_dir must be absolute: {repr(save_dir)}' - # save kernel - H.torch_write(os.path.join(save_dir, cfg.exp.save_name), framework.model._kernel) - - -# ========================================================================= # -# Entry Point # -# 
========================================================================= # - - -if __name__ == '__main__': - # HYDRA: - # run experiment (12min * 4*8*2) / 60 ~= 12 hours - # but speeds up as kernel size decreases, so might be shorter - # EXP ARGS: - # $ ... -m optimizer.weight_decay=1e-4,0.0 kernel.radius=63,55,47,39,31,23,15,7 dataset.spacing=8,4,2,1 - run_disentangle_dataset_kernel() diff --git a/research/e05_disentangle_kernel/submit_03.sh b/research/e05_disentangle_kernel/submit_03.sh deleted file mode 100644 index bd5e6e6a..00000000 --- a/research/e05_disentangle_kernel/submit_03.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-03__kernel-disentangle-xy" -export PARTITION="stampede" -export PARALLELISM=32 -export PY_RUN_FILE='experiment/exp/05_adversarial_data/run_03_train_disentangle_kernel.py' - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# 1 * (2*8*4) == 64 -submit_sweep \ - optimizer.weight_decay=1e-4,0.0 \ - kernel.radius=63,55,47,39,31,23,15,7 \ - data.name=xysquares_8x8,xysquares_4x4,xysquares_2x2,xysquares_1x1 diff --git a/research/e06_adversarial_data/deprecated/run_01_gen_adversarial_disk.py b/research/e06_adversarial_data/deprecated/run_01_gen_adversarial_disk.py deleted file mode 100644 index 4d323ab5..00000000 --- a/research/e06_adversarial_data/deprecated/run_01_gen_adversarial_disk.py +++ /dev/null @@ -1,497 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -""" -Generate an adversarial dataset -- Stores the mutating dataset on disk -- Loads minibatches from disk that are optimized and then saved back to the disk -- No model is used, images are directly optimized against each other, could decay in some cases? - -This is quite memory efficient, but it is quite old!
-- Should probably be re-written using ray -""" - - -import logging -import multiprocessing.synchronize -import os -from concurrent.futures import Executor -from concurrent.futures import Future -from concurrent.futures import ProcessPoolExecutor -from typing import Optional -from typing import Sequence - -import h5py -import numpy as np -import psutil -import torch -from tqdm import tqdm - -import research.util as H -from disent.util.deprecate import deprecated -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.profiling import Timer -from disent.util.seeds import seed - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# losses # -# ========================================================================= # - - -def stochastic_const_loss(pred: torch.Tensor, mask: torch.Tensor, num_pairs: int, num_samples: int, loss='mse', reg_out_of_bounds=True, top_k: int = None, constant_targ: float = None) -> torch.Tensor: - ia, ib = torch.randint(0, len(pred), size=(2, num_samples), device=pred.device) - # constant dist loss - x_ds = (H.unreduced_loss(pred[ia], pred[ib], mode=loss) * mask[None, ...]).mean(dim=(-3, -2, -1)) - # compute constant loss - if constant_targ is None: - iA, iB = torch.randint(0, len(x_ds), size=(2, num_pairs), device=pred.device) - lcst = H.unreduced_loss(x_ds[iA], x_ds[iB], mode=loss) - else: - lcst = H.unreduced_loss(x_ds, torch.full_like(x_ds, constant_targ), mode=loss) - # aggregate constant loss - if top_k is None: - lcst = lcst.mean() - else: - lcst = torch.topk(lcst, k=top_k, largest=True).values.mean() - # values over the required range - if reg_out_of_bounds: - m = torch.nan_to_num((0 - pred[pred < 0]) ** 2, nan=0).mean() - M = torch.nan_to_num((pred[pred > 1] - 1) ** 2, nan=0).mean() - mM = m + M - else: - mM = 0. - # done! 
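- # (final loss = out-of-bounds penalty + aggregated constant-distance loss)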
- return mM + lcst - - -# ========================================================================= # -# h5py dataset helper # -# ========================================================================= # - - -NAME_DATA = 'data' -NAME_VISITS = 'visits' -NAME_OBS = 'x_targ' - -_SAVE_TYPE_LOOKUP = { - 'uint8': torch.uint8, - 'float16': torch.float16, - 'float32': torch.float32, -} - -SAVE_TYPE = 'float16' -assert SAVE_TYPE in _SAVE_TYPE_LOOKUP - - -def _make_hdf5_dataset(path, dataset, overwrite_mode: str = 'continue') -> str: - path = ensure_parent_dir_exists(path) - # get read/write mode - if overwrite_mode == 'overwrite': - rw_mode = 'w' # create new file, overwrite if exists - elif overwrite_mode == 'fail': - rw_mode = 'x' # create new file, fail if exists - elif overwrite_mode == 'continue': - rw_mode = 'a' # create if missing, append if exists - # clear file consistency flags - # if clear_consistency_flags: - # if os.path.isfile(path): - # cmd = ["h5clear", "-s", "'{path}'"] - # print(f'clearing file consistency flags: {" ".join(cmd)}') - # try: - # subprocess.check_output(cmd) - # except FileNotFoundError: - # raise FileNotFoundError('h5clear utility is not installed!') - else: - raise KeyError(f'invalid overwrite_mode={repr(overwrite_mode)}') - # open in read write mode - log.info(f'Opening hdf5 dataset: overwrite_mode={repr(overwrite_mode)} exists={repr(os.path.exists(path))} path={repr(path)}') - with h5py.File(path, rw_mode, libver='earliest') as f: - # get data - num_obs = len(dataset) - obs_shape = dataset[0][NAME_OBS][0].shape - # make dset - if NAME_DATA not in f: - f.create_dataset( - NAME_DATA, - shape=(num_obs, *obs_shape), - dtype=SAVE_TYPE, - chunks=(1, *obs_shape), - track_times=False, - ) - # make set_dset - if NAME_VISITS not in f: - f.create_dataset( - NAME_VISITS, - shape=(num_obs,), - dtype='int64', - chunks=(1,), - track_times=False, - ) - return path - - -# def _read_hdf5_batch(h5py_path: str, idxs, return_visits=False): -# batch, visits = [], [] -# with h5py.File(h5py_path, 'r', swmr=True) as f: -# for i in idxs: -# visits.append(f[NAME_VISITS][i]) -# obs = torch.as_tensor(f[NAME_DATA][i], dtype=torch.float32) -# if SAVE_TYPE == 'uint8': -# obs /= 255 -# batch.append(obs) -# # return values -# if return_visits: -# return torch.stack(batch, dim=0), np.array(visits, dtype=np.int64) -# else: -# return torch.stack(batch, dim=0) - - -def _load_hdf5_batch(dataset, h5py_path: str, idxs, initial_noise: Optional[float] = None, return_visits=True): - """ - Load a batch from the disk -- always return float32 - - Can be used by multiple threads at a time. - - returns an item from the original dataset if an - observation has not been saved into the hdf5 dataset yet. - """ - batch, visits = [], [] - with h5py.File(h5py_path, 'r', swmr=True) as f: - for i in idxs: - v = f[NAME_VISITS][i] - if v > 0: - obs = torch.as_tensor(f[NAME_DATA][i], dtype=torch.float32) - if SAVE_TYPE == 'uint8': - obs /= 255 - else: - (obs,) = dataset[i][NAME_OBS] - obs = obs.to(torch.float32) - if initial_noise is not None: - obs += (torch.randn_like(obs) * initial_noise) - batch.append(obs) - visits.append(v) - # stack and check values - batch = torch.stack(batch, dim=0) - assert batch.dtype == torch.float32 - # return values - if return_visits: - return batch, np.array(visits, dtype=np.int64) - else: - return batch - - -def _save_hdf5_batch(h5py_path: str, batch, idxs): - """ - Save a float32 batch to disk. - - Can only be used by one thread at a time! 
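- (writes here use a plain 'r+' handle, while readers open with swmr=True, so only a single writer is safe at once)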
- """ - assert batch.dtype == torch.float32 - with h5py.File(h5py_path, 'r+', libver='earliest') as f: - for obs, idx in zip(batch, idxs): - if SAVE_TYPE == 'uint8': - f[NAME_DATA][idx] = torch.clamp(torch.round(obs * 255), 0, 255).to(torch.uint8) - else: - f[NAME_DATA][idx] = obs.to(_SAVE_TYPE_LOOKUP[SAVE_TYPE]) - f[NAME_VISITS][idx] += 1 - - -# ========================================================================= # -# multiproc h5py dataset helper # -# ========================================================================= # - - -class FutureList(object): - def __init__(self, futures: Sequence[Future]): - self._futures = futures - - def result(self): - return [future.result() for future in self._futures] - - -# ========================================================================= # -# multiproc h5py dataset helper # -# ========================================================================= # - - -# SUBMIT: - - -def _submit_load_batch_futures(executor: Executor, num_splits: int, dataset, h5py_path: str, idxs, initial_noise: Optional[float] = None) -> FutureList: - return FutureList([ - executor.submit(__inner__load_batch, dataset=dataset, h5py_path=h5py_path, idxs=idxs, initial_noise=initial_noise) - for idxs in np.array_split(idxs, num_splits) - ]) - - -def _submit_save_batch(executor: Executor, h5py_path: str, batch, idxs) -> Future: - return executor.submit(__inner__save_batch, h5py_path=h5py_path, batch=batch, idxs=idxs) - - -NUM_WORKERS = psutil.cpu_count() -_BARRIER = None - - -def __inner__load_batch(dataset, h5py_path: str, idxs, initial_noise: Optional[float] = None): - _BARRIER.wait() - result = _load_hdf5_batch(dataset=dataset, h5py_path=h5py_path, idxs=idxs, initial_noise=initial_noise) - _BARRIER.wait() - return result - - -def __inner__save_batch(h5py_path, batch, idxs): - _save_hdf5_batch(h5py_path=h5py_path, batch=batch, idxs=idxs) - - -# WAIT: - - -def _wait_for_load_future(future: FutureList): - with Timer() as t: - xs, visits = zip(*future.result()) - xs = torch.cat(xs, dim=0) - visits = np.concatenate(visits, axis=0).mean(dtype=np.float32) - return (xs, visits), t - - -def _wait_for_save_future(future: Future): - with Timer() as t: - future.result() - return t - - -# ========================================================================= # -# adversarial dataset generator # -# ========================================================================= # - - -def run_generate_and_save_adversarial_dataset_mp( - dataset_name: str = 'shapes3d', - dataset_load_into_memory: bool = False, - optimizer: str = 'adam', - lr: float = 1e-2, - obs_masked: bool = True, - obs_initial_noise: Optional[float] = None, - loss_fn: str = 'mse', - batch_size: int = 1024*12, # approx - batch_sample_mode: str = 'shuffle', # range, shuffle, random - loss_num_pairs: int = 1024*4, - loss_num_samples: int = 1024*4*2, # only applies if loss_const_targ=None - loss_top_k: Optional[int] = None, - loss_const_targ: Optional[float] = 0.1, # replace stochastic pairwise constant loss with deterministic loss target - loss_reg_out_of_bounds: bool = False, - train_epochs: int = 8, - train_optim_steps: int = 125, - # skipped params - save_folder: str = 'out/overlap', - save_prefix: str = '', - overwrite_mode: str = 'fail', # continue, overwrite, fail - seed_: Optional[int] = 777, -) -> str: - # checks - if obs_initial_noise is not None: - assert not obs_masked, '`obs_masked` cannot be `True`, if using initial noise, ie. `obs_initial_noise is not None`' - - # deterministic! 
- seed(seed_) - - # ↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓ # - # make dataset - dataset = H.make_dataset(dataset_name, load_into_memory=dataset_load_into_memory, load_memory_dtype=torch.float16) - # get save path - assert not ('/' in save_prefix or '\\' in save_prefix) - name = H.params_as_string(H.get_caller_params(exclude=["save_folder", "save_prefix", "overwrite_mode", "seed_"])) - path = _make_hdf5_dataset(os.path.join(save_folder, f'{save_prefix}{name}.hdf5'), dataset=dataset, overwrite_mode=overwrite_mode) - # ↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑ # - - train_batches = (len(dataset) + batch_size - 1) // batch_size - # loop vars & progress bar - save_time = Timer() - prog = tqdm(total=train_epochs * train_batches * train_optim_steps, postfix={'loss': 0.0, '💯': 0.0, '🔍': 'N/A', '💾': 'N/A'}, ncols=100) - # multiprocessing pool - global _BARRIER # TODO: this is a hack and should be unique to each run - _BARRIER = multiprocessing.Barrier(NUM_WORKERS) - executor = ProcessPoolExecutor(NUM_WORKERS) - - # EPOCHS: - for e in range(train_epochs): - # generate batches - batch_idxs = H.generate_epoch_batch_idxs(num_obs=len(dataset), num_batches=train_batches, mode=batch_sample_mode) - # first data load - load_future = _submit_load_batch_futures(executor, num_splits=NUM_WORKERS, dataset=dataset, h5py_path=path, idxs=batch_idxs[0], initial_noise=obs_initial_noise) - - # TODO: log to WANDB - # TODO: SAMPLING STRATEGY MIGHT NEED TO CHANGE! - # - currently random pairs are generated, but the pairs that matter are the nearby ones. - # - sample pairs that increase and decrease along an axis - # - sample pairs that are nearby according to the factor distance metric - - # BATCHES: - for n in range(len(batch_idxs)): - # ↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓ # - # get batch -- transfer to gpu is the bottleneck - (x, visits), load_time = _wait_for_load_future(load_future) - x = x.cuda().requires_grad_(True) - # ↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑ # - - # queue loading an extra batch - if (n+1) < len(batch_idxs): - load_future = _submit_load_batch_futures(executor, num_splits=NUM_WORKERS, dataset=dataset, h5py_path=path, idxs=batch_idxs[n + 1], initial_noise=obs_initial_noise) - - # ↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓ # - # make optimizers - mask = H.make_changed_mask(x, masked=obs_masked) - optim = H.make_optimizer(x, name=optimizer, lr=lr) - # ↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑ # - - # OPTIMIZE: - for _ in range(train_optim_steps): - # final loss & update - # ↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓=↓ # - loss = stochastic_const_loss(x, mask, num_pairs=loss_num_pairs, num_samples=loss_num_samples, loss=loss_fn, reg_out_of_bounds=loss_reg_out_of_bounds, top_k=loss_top_k, constant_targ=loss_const_targ) - H.step_optimizer(optim, loss) - # ↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑=↑ # - - # update progress bar - logs = {'loss': float(loss), '💯': visits, '🔍': load_time.pretty, '💾': save_time.pretty} - prog.update() - prog.set_postfix(logs) - - # save optimized minibatch - if n > 0: - save_time = _wait_for_save_future(save_future) - save_future = _submit_save_batch(executor, h5py_path=path, batch=x.detach().cpu(), idxs=batch_idxs[n]) - - # final save - save_time = _wait_for_save_future(save_future) - - # cleanup all - executor.shutdown() - # return the path to the dataset - return path - - -# ========================================================================= # -# test adversarial dataset generator # -# ========================================================================= # - - -@deprecated('Replaced with run_02_gen_adversarial_dataset_approx') -def 
run_generate_adversarial_data( - dataset: str ='shapes3d', - factor: str ='wall_hue', - factor_mode: str = 'sample_random', - optimizer: str ='radam', - lr: float = 1e-2, - obs_num: int = 1024 * 10, - obs_noise_weight: float = 0, - obs_masked: bool = True, - loss_fn: str = 'mse', - loss_num_pairs: int = 4096, - loss_num_samples: int = 4096*2, # only applies if loss_const_targ=None - loss_top_k: int = None, - loss_const_targ: float = None, # replace stochastic pairwise constant loss with deterministic loss target - loss_reg_out_of_bounds: bool = False, - train_steps: int = 2000, - display_period: int = 500, -): - seed(777) - # make dataset - dataset = H.make_dataset(dataset) - # make batches - factors = H.sample_factors(dataset, num_obs=obs_num, factor_mode=factor_mode, factor=factor) - x = dataset.dataset_batch_from_factors(factors, 'target') - # make tensors to optimize - if torch.cuda.is_available(): - x = x.cuda() - x = torch.tensor(x + torch.randn_like(x) * obs_noise_weight, requires_grad=True) - # generate mask - mask = H.make_changed_mask(x, masked=obs_masked) - H.plt_imshow(H.to_img(mask.to(torch.float32)), show=True) - # make optimizer - optimizer = H.make_optimizer(x, name=optimizer, lr=lr) - - # optimize differences according to loss - prog = tqdm(range(train_steps+1), postfix={'loss': 0.0}) - for i in prog: - # final loss - loss = stochastic_const_loss(x, mask, num_pairs=loss_num_pairs, num_samples=loss_num_samples, loss=loss_fn, reg_out_of_bounds=loss_reg_out_of_bounds, top_k=loss_top_k, constant_targ=loss_const_targ) - # update variables - H.step_optimizer(optimizer, loss) - if i % display_period == 0: - log.warning(f'visualisation of `x[:9]` was disabled') - prog.set_postfix({'loss': float(loss)}) - - -# ========================================================================= # -# entrypoint # -# ========================================================================= # - -# TODO: add WANDB support for visualisation of dataset -# TODO: add graphing of visual overlap like exp 01 - -def main(): - logging.basicConfig(level=logging.INFO, format='(%(asctime)s) %(name)s:%(lineno)d [%(levelname)s]: %(message)s') - - paths = [] - for i, kwargs in enumerate([ - # dict(save_prefix='e128__fixed_unmask_const_', obs_masked=False, loss_const_targ=0.1, obs_initial_noise=None, optimizer='adam', dataset_name='cars3d'), - # dict(save_prefix='e128__fixed_unmask_const_', obs_masked=False, loss_const_targ=0.1, obs_initial_noise=None, optimizer='adam', dataset_name='smallnorb'), - # dict(save_prefix='e128__fixed_unmask_randm_', obs_masked=False, loss_const_targ=None, obs_initial_noise=None, optimizer='adam', dataset_name='cars3d'), - # dict(save_prefix='e128__fixed_unmask_randm_', obs_masked=False, loss_const_targ=None, obs_initial_noise=None, optimizer='adam', dataset_name='smallnorb'), - ]): - # generate dataset - try: - path = run_generate_and_save_adversarial_dataset_mp( - train_epochs=128, - train_optim_steps=175, - seed_=777, - overwrite_mode='overwrite', - dataset_load_into_memory=True, - lr=5e-3, - # batch_sample_mode='range', - **kwargs - ) - paths.append(path) - except Exception as e: - log.error(f'[{i}] FAILED RUN: {e} -- {repr(kwargs)}', exc_info=True) - # load some samples and display them - try: - log.warning(f'visualisation of `_read_hdf5_batch(paths[-1], display_idxs)` was disabled') - except Exception as e: - log.warning(f'[{i}] FAILED SHOW: {e} -- {repr(kwargs)}') - - for path in paths: - print(path) - - -# 
========================================================================= # -# main # -# ========================================================================= # - - -if __name__ == '__main__': - main() diff --git a/research/e06_adversarial_data/deprecated/run_02_adv_dataset.sh b/research/e06_adversarial_data/deprecated/run_02_adv_dataset.sh deleted file mode 100644 index e864de99..00000000 --- a/research/e06_adversarial_data/deprecated/run_02_adv_dataset.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# get the path to the script -PARENT_DIR="$(dirname "$(realpath -s "$0")")" -ROOT_DIR="$(dirname "$(dirname "$(dirname "$PARENT_DIR")")")" - -# TODO: fix this! -# TODO: this is out of date -PYTHONPATH="$ROOT_DIR" python3 "$PARENT_DIR/run_02_gen_adversarial_dataset.py" \ - -m \ - framework.sampler_name=same_k,close_far,same_factor,random_bb \ - framework.loss_mode=self,const,invert \ - framework.dataset_name=cars3d,smallnorb diff --git a/research/e06_adversarial_data/deprecated/run_02_gen_adversarial_dataset.py b/research/e06_adversarial_data/deprecated/run_02_gen_adversarial_dataset.py deleted file mode 100644 index 1260233b..00000000 --- a/research/e06_adversarial_data/deprecated/run_02_gen_adversarial_dataset.py +++ /dev/null @@ -1,436 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -Generate an adversarial dataset -- images are directly optimized against each other, could decay in some cases? -- All data is stored in memory, with minibatches taken and optimized. 
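-- Unlike the disk-based run_01 variant, minibatches can be wrapped in temporary nn.Parameters and written back into the in-memory array after each optimisation step.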
-""" - -import logging -import os -import warnings -from datetime import datetime -from typing import Iterator -from typing import List -from typing import Optional -from typing import Sequence - -import hydra -import numpy as np -import pytorch_lightning as pl -import torch -import wandb -from omegaconf import OmegaConf -from torch.utils.data import DataLoader -from torch.utils.data import IterableDataset -from torch.utils.data.dataset import T_co - -import research.util as H -from disent.dataset import DisentDataset -from disent.dataset.sampling import BaseDisentSampler -from disent.dataset.util.hdf5 import H5Builder -from disent.util import to_numpy -from disent.util.deprecate import deprecated -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.lightning.callbacks import BaseCallbackPeriodic -from disent.util.lightning.callbacks import LoggerProgressCallback -from disent.util.lightning.logger_util import wb_log_metrics -from disent.util.math.random import random_choice_prng -from disent.util.seeds import seed -from disent.util.seeds import TempNumpySeed -from disent.util.strings.fmt import bytes_to_human -from disent.util.strings.fmt import make_box_str -from disent.util.visualize.vis_util import make_image_grid -from experiment.run import hydra_get_callbacks -from experiment.run import hydra_get_gpus -from experiment.run import hydra_make_logger -from experiment.util.hydra_utils import make_non_strict -from experiment.util.run_utils import log_error_and_exit -from research.e06_adversarial_data.util_gen_adversarial_dataset import adversarial_loss -from research.e06_adversarial_data.util_gen_adversarial_dataset import make_adversarial_sampler - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# adversarial dataset generator # -# ========================================================================= # - - -class AdversarialModel(pl.LightningModule): - - def __init__( - self, - # optimizer options - optimizer_name: str = 'sgd', - optimizer_lr: float = 5e-2, - optimizer_kwargs: Optional[dict] = None, - # dataset config options - dataset_name: str = 'cars3d', - dataset_num_workers: int = min(os.cpu_count(), 16), - dataset_batch_size: int = 1024, # approx - data_root: str = 'data/dataset', - # data_load_into_memory: bool = False, - # adversarial loss options - adversarial_mode: str = 'self', - adversarial_swapped: bool = False, - adversarial_masking: bool = False, - adversarial_top_k: Optional[int] = None, - pixel_loss_mode: str = 'mse', - # loss extras - # loss_adversarial_weight: Optional[float] = 1.0, - # loss_same_stats_weight: Optional[float] = 0.0, - # loss_similarity_weight: Optional[float] = 0.0, - # loss_out_of_bounds_weight: Optional[float] = 0.0, - # sampling config - sampler_name: str = 'close_far', - # train options - train_batch_optimizer: bool = True, - train_dataset_fp16: bool = True, - train_is_gpu: bool = False, - # logging settings - # logging_scale_imgs: bool = False, - ): - super().__init__() - # check values - if train_dataset_fp16 and (not train_is_gpu): - warnings.warn('`train_dataset_fp16=True` is not supported on CPU, overriding setting to `False`') - train_dataset_fp16 = False - self._dtype_dst = torch.float32 - self._dtype_src = torch.float16 if train_dataset_fp16 else torch.float32 - # modify hparams - if optimizer_kwargs is None: - optimizer_kwargs = {} - # save hparams - self.save_hyperparameters() - # variables - self.dataset: DisentDataset = None - 
self.array: torch.Tensor = None - self.sampler: BaseDisentSampler = None - - # ================================== # - # setup # - # ================================== # - - def prepare_data(self) -> None: - # create dataset - self.dataset = H.make_dataset(self.hparams.dataset_name, load_into_memory=True, load_memory_dtype=self._dtype_src, data_root=self.hparams.data_root) - # load dataset into memory as fp16 - if self.hparams.train_batch_optimizer: - self.array = self.dataset.gt_data.array - else: - self.array = torch.nn.Parameter(self.dataset.gt_data.array, requires_grad=True) # move with model to correct device - # create sampler - self.sampler = make_adversarial_sampler(self.hparams.sampler_name) - self.sampler.init(self.dataset.gt_data) - - def _make_optimizer(self, params): - return H.make_optimizer( - params, - name=self.hparams.optimizer_name, - lr=self.hparams.optimizer_lr, - **self.hparams.optimizer_kwargs, - ) - - def configure_optimizers(self): - if self.hparams.train_batch_optimizer: - return None - else: - return self._make_optimizer(self.array) - - # ================================== # - # train step # - # ================================== # - - def training_step(self, batch, batch_idx): - # get indices - (a_idx, p_idx, n_idx) = batch['idx'] - # generate batches & transfer to correct device - if self.hparams.train_batch_optimizer: - (a_x, p_x, n_x), (params, param_idxs, optimizer) = self._load_batch(a_idx, p_idx, n_idx) - else: - a_x = self.array[a_idx] - p_x = self.array[p_idx] - n_x = self.array[n_idx] - # compute loss - loss = adversarial_loss( - ys=(a_x, p_x, n_x), - xs=None, - adversarial_mode=self.hparams.adversarial_mode, - adversarial_swapped=self.hparams.adversarial_swapped, - adversarial_masking=self.hparams.adversarial_masking, - adversarial_top_k=self.hparams.adversarial_top_k, - pixel_loss_mode=self.hparams.pixel_loss_mode, - ) - # log results - self.log_dict({ - 'loss': loss, - 'adv_loss': loss, - }, prog_bar=True) - # done! - if self.hparams.train_batch_optimizer: - self._update_with_batch(loss, params, param_idxs, optimizer) - return None - else: - return loss - - # ================================== # - # optimizer for each batch mode # - # ================================== # - - def _load_batch(self, a_idx, p_idx, n_idx): - with torch.no_grad(): - # get all indices - all_indices = np.stack([ - a_idx.detach().cpu().numpy(), - p_idx.detach().cpu().numpy(), - n_idx.detach().cpu().numpy(), - ], axis=0) - # find unique values - param_idxs, inverse_indices = np.unique(all_indices.flatten(), return_inverse=True) - inverse_indices = inverse_indices.reshape(all_indices.shape) - # load data with values & move to gpu - # - for batch size (256*3, 3, 64, 64) with num_workers=0, this is 5% faster - # than .to(device=self.device, dtype=DST_DTYPE) in one call, as we reduce - # the memory overhead in the transfer. This does slightly increase the - # memory usage on the target device. - # - for batch size (1024*3, 3, 64, 64) with num_workers=12, this is 15% faster - # but consumes slightly more memory: 2492MiB vs. 
2510MiB - params = self.array[param_idxs].to(device=self.device).to(dtype=self._dtype_dst) - # make params and optimizer - params = torch.nn.Parameter(params, requires_grad=True) - optimizer = self._make_optimizer(params) - # get batches -- it is ok to index by a numpy array without conversion - a_x = params[inverse_indices[0, :]] - p_x = params[inverse_indices[1, :]] - n_x = params[inverse_indices[2, :]] - # return values - return (a_x, p_x, n_x), (params, param_idxs, optimizer) - - def _update_with_batch(self, loss, params, param_idxs, optimizer): - with TempNumpySeed(777): - std, mean = torch.std_mean(self.array[np.random.randint(0, len(self.array), size=128)]) - std, mean = std.cpu().numpy().tolist(), mean.cpu().numpy().tolist() - self.log_dict({'approx_mean': mean, 'approx_std': std}, prog_bar=True) - # backprop - H.step_optimizer(optimizer, loss) - # save values to dataset - with torch.no_grad(): - self.array[param_idxs] = params.detach().cpu().to(self._dtype_src) - - # ================================== # - # dataset # - # ================================== # - - def train_dataloader(self): - # sampling in dataloader - sampler = self.sampler - data_len = len(self.dataset.gt_data) - # generate the indices in a multi-threaded environment -- this is not deterministic if num_workers > 0 - class SamplerIndicesDataset(IterableDataset): - def __getitem__(self, index) -> T_co: - raise RuntimeError('this should never be called on an iterable dataset') - def __iter__(self) -> Iterator[T_co]: - while True: - yield {'idx': sampler(np.random.randint(0, data_len))} - # create data loader! - return DataLoader( - SamplerIndicesDataset(), - batch_size=self.hparams.dataset_batch_size, - num_workers=self.hparams.dataset_num_workers, - shuffle=False, - ) - - def make_train_periodic_callbacks(self, cfg) -> Sequence[pl.Callback]: - class ImShowCallback(BaseCallbackPeriodic): - def do_step(this, trainer: pl.Trainer, pl_module: pl.LightningModule): - if self.dataset is None: - log.warning('dataset not initialized, skipping visualisation') - return - # get dataset images - with TempNumpySeed(777): - # get scaling values - samples = self.dataset.dataset_sample_batch(num_samples=128, mode='raw').to(torch.float32) - m, M = float(torch.min(samples)), float(torch.max(samples)) - # add transform to dataset - self.dataset._transform = lambda x: H.to_img((x.to(torch.float32) - m) / (M - m)) # this is hacky, scale values to [0, 1] then to [0, 255] - # get images - image = make_image_grid(self.dataset.dataset_sample_batch(num_samples=16, mode='input')) - # get augmented traversals - with torch.no_grad(): - wandb_image, wandb_animation = H.visualize_dataset_traversal(self.dataset, data_mode='input', output_wandb=True) - # log images to WANDB - wb_log_metrics(trainer.logger, { - 'random_images': wandb.Image(image), - 'traversal_image': wandb_image, 'traversal_animation': wandb_animation, - }) - return [ImShowCallback(every_n_steps=cfg.exp.show_every_n_steps, begin_first_step=True)] - - -# ========================================================================= # -# Run Hydra # -# ========================================================================= # - - -ROOT_DIR = os.path.abspath(__file__ + '/../../../..') - - -@deprecated('Replaced with run_02_gen_adversarial_dataset_approx') -def run_gen_adversarial_dataset(cfg): - time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') - log.info(f'Starting run at time: {time_string}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # cleanup from old runs: - try: - wandb.finish() -
except: - pass - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - cfg = make_non_strict(cfg) - # - - - - - - - - - - - - - - - # - # check CUDA setting - gpus = hydra_get_gpus(cfg) - # create logger - logger = hydra_make_logger(cfg) - # create callbacks - callbacks: List[pl.Callback] = [c for c in hydra_get_callbacks(cfg) if isinstance(c, LoggerProgressCallback)] - # - - - - - - - - - - - - - - - # - # check save dirs - assert not os.path.isabs(cfg.settings.exp.rel_save_dir), f'rel_save_dir must be relative: {repr(cfg.settings.exp.rel_save_dir)}' - save_dir = os.path.join(ROOT_DIR, cfg.settings.exp.rel_save_dir) - assert os.path.isabs(save_dir), f'save_dir must be absolute: {repr(save_dir)}' - # - - - - - - - - - - - - - - - # - # get the logger and initialize - if logger is not None: - logger.log_hyperparams(cfg) - # print the final config! - log.info('Final Config' + make_box_str(OmegaConf.to_yaml(cfg))) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # | | | | | | | | | | | | | | | # - seed(cfg.settings.job.seed) - # | | | | | | | | | | | | | | | # - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # make framework - framework = AdversarialModel(train_is_gpu=cfg.trainer.cuda, **cfg.adv_system) - callbacks.extend(framework.make_train_periodic_callbacks(cfg)) - # train - trainer = pl.Trainer( - logger=logger, - callbacks=callbacks, - # cfg.dsettings.trainer - gpus=gpus, - # cfg.trainer - max_epochs=cfg.trainer.max_epochs, - max_steps=cfg.trainer.max_steps, - log_every_n_steps=cfg.trainer.log_every_n_steps, - flush_logs_every_n_steps=cfg.trainer.flush_logs_every_n_steps, - progress_bar_refresh_rate=cfg.trainer.progress_bar_refresh_rate, - prepare_data_per_node=cfg.trainer.prepare_data_per_node, - # we do this here so we don't run the final metrics - terminate_on_nan=True, - checkpoint_callback=False, - ) - trainer.fit(framework) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # get save paths - save_prefix = f'{cfg.settings.exp.save_prefix}_' if cfg.settings.exp.save_prefix else '' - save_path_data = os.path.join(save_dir, f'{save_prefix}{time_string}_{cfg.settings.job.name}', f'data.h5') - # create directories - if cfg.settings.exp.save_data: ensure_parent_dir_exists(save_path_data) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # compute standard deviation when saving and scale so - # that we have mean=0 and std=1 of the saved data! - with TempNumpySeed(777): - std, mean = torch.std_mean(framework.array[random_choice_prng(len(framework.array), size=2048, replace=False)]) - std, mean = float(std), float(mean) - log.info(f'normalizing saved dataset of shape: {tuple(framework.array.shape)} and dtype: {framework.array.dtype} with mean: {repr(mean)} and std: {repr(std)}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # save adversarial dataset - if cfg.settings.exp.save_data: - log.info(f'saving data to path: {repr(save_path_data)}') - # transfer to GPU - if torch.cuda.is_available(): - framework = framework.cuda() - # create new h5py file -- TODO: use this in other places! 
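Before the `H5Builder` block below, it may help to note how the resulting self-contained file could be read back. A minimal sketch, assuming the observations land under a `'data'` key with the `norm_mean`/`norm_std` attrs attached to that dataset (both are assumptions here; `SelfContainedHdf5GroundTruthData` is the intended loader):

```python
import h5py
import numpy as np

def load_denormalized(path: str, key: str = 'data') -> np.ndarray:
    # hypothetical reader: the key name and attr placement are assumptions
    with h5py.File(path, 'r') as f:
        dset = f[key]
        mean = float(dset.attrs['norm_mean'])
        std = float(dset.attrs['norm_std'])
        # data is saved as float16 with approx. mean=0 and std=1 -- undo that here
        return dset[...].astype('float32') * std + mean
```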
- with H5Builder(path=save_path_data, mode='atomic_w') as builder: - # this dataset is self-contained and can be loaded by SelfContainedHdf5GroundTruthData - # we normalize the values to have approx mean of 0 and std of 1 - builder.add_dataset_from_gt_data( - data=framework.dataset, # produces tensors - mutator=lambda x: np.moveaxis((to_numpy(x).astype('float32') - mean) / std, -3, -1).astype('float16'), # consumes tensors -> np.ndarrays - img_shape=(64, 64, None), - compression_lvl=9, - batch_size=32, - dtype='float16', - attrs=dict( - norm_mean=mean, - norm_std=std, - ) - ) - log.info(f'saved data size: {bytes_to_human(os.path.getsize(save_path_data))}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - - -# ========================================================================= # -# Entry Point # -# ========================================================================= # - - -if __name__ == '__main__': - - # BENCHMARK (batch_size=256, optimizer=sgd, lr=1e-2, dataset_num_workers=0): - # - batch_optimizer=False, gpu=True, fp16=True : [3168MiB/5932MiB, 3.32/11.7G, 5.52it/s] - # - batch_optimizer=False, gpu=True, fp16=False : [5248MiB/5932MiB, 3.72/11.7G, 4.84it/s] - # - batch_optimizer=False, gpu=False, fp16=True : [same as fp16=False] - # - batch_optimizer=False, gpu=False, fp16=False : [0003MiB/5932MiB, 4.60/11.7G, 1.05it/s] - # --------- - # - batch_optimizer=True, gpu=True, fp16=True : [1284MiB/5932MiB, 3.45/11.7G, 4.31it/s] - # - batch_optimizer=True, gpu=True, fp16=False : [1284MiB/5932MiB, 3.72/11.7G, 4.31it/s] - # - batch_optimizer=True, gpu=False, fp16=True : [same as fp16=False] - # - batch_optimizer=True, gpu=False, fp16=False : [0003MiB/5932MiB, 1.80/11.7G, 4.18it/s] - - # BENCHMARK (batch_size=1024, optimizer=sgd, lr=1e-2, dataset_num_workers=12): - # - batch_optimizer=True, gpu=True, fp16=True : [2510MiB/5932MiB, 4.10/11.7G, 4.75it/s, 20% gpu util] (to(device).to(dtype)) - # - batch_optimizer=True, gpu=True, fp16=True : [2492MiB/5932MiB, 4.10/11.7G, 4.12it/s, 19% gpu util] (to(device, dtype)) - - @hydra.main(config_path=os.path.join(ROOT_DIR, 'experiment/config'), config_name="config_adversarial_dataset") - def main(cfg): - try: - run_gen_adversarial_dataset(cfg) - except Exception as e: - # truncate error - err_msg = str(e) - err_msg = err_msg[:244] + ' ' if len(err_msg) > 244 else err_msg - # log something at least - log.error(f'exiting: experiment error | {err_msg}', exc_info=True) - - # EXP ARGS: - # $ ... 
-m dataset=smallnorb,shapes3d - try: - main() - except KeyboardInterrupt as e: - log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) - except Exception as e: - log_error_and_exit(err_type='hydra error', err_msg=str(e)) diff --git a/research/e06_adversarial_data/deprecated/run_03_check.py b/research/e06_adversarial_data/deprecated/run_03_check.py deleted file mode 100644 index d15c8eeb..00000000 --- a/research/e06_adversarial_data/deprecated/run_03_check.py +++ /dev/null @@ -1,86 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -""" -Check the adversarial data generated in previous experiments -- This is old and outdated... -- Should use `e01_visual_overlap/run_plot_traversal_dists.py` instead!
-""" - - -import numpy as np -import torch -import torch.nn.functional as F -import torchvision -import matplotlib.pyplot as plt - -from disent.dataset.data import Shapes3dData -from research.util._data import AdversarialOptimizedData - - -if __name__ == '__main__': - - def ave_pairwise_dist(data, n_samples=1000): - """ - Get the average distance between observations in the dataset - """ - # get stats - diff = [] - for i in range(n_samples): - a, b = np.random.randint(0, len(data), size=2) - a, b = data[a], data[b] - diff.append(F.mse_loss(a, b, reduction='mean').item()) - return np.mean(diff) - - def plot_samples(data, name=None): - """ - Display random observations from the dataset - """ - # get image - img = torchvision.utils.make_grid([data[i*1000] for i in range(9)], nrow=3) - img = torch.moveaxis(img, 0, -1).numpy() - # plot - if name is not None: - plt.title(name) - plt.imshow(img) - plt.show() - - - def main(): - base_data = Shapes3dData(in_memory=False, prepare=True, transform=torchvision.transforms.ToTensor()) - plot_samples(base_data) - print(ave_pairwise_dist(base_data)) - - for path in [ - 'out/overlap/fixed_masked_const_shapes3d_adam_0.01_True_None_mse_12288_shuffle_5120_10240_None_0.1_False_8_125.hdf5', - 'out/overlap/fixed_masked_randm_shapes3d_adam_0.01_True_None_mse_12288_shuffle_5120_10240_None_None_False_8_125.hdf5', - 'out/overlap/noise_unmask_randm_shapes3d_adam_0.01_False_0.001_mse_12288_shuffle_5120_10240_None_None_False_8_125.hdf5', - 'out/overlap/noise_unmask_randm_shapes3d_adam_0.01_False_0.1_mse_12288_shuffle_5120_10240_None_None_False_8_125.hdf5', - ]: - data = AdversarialOptimizedData(path, base_data, transform=torchvision.transforms.ToTensor()) - plot_samples(data) - print(ave_pairwise_dist(data)) - - main() diff --git a/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck.py b/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck.py deleted file mode 100644 index 9aee9e08..00000000 --- a/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck.py +++ /dev/null @@ -1,585 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -This file generates pareto-optimal solutions to the multi-objective -optimisation problem of masking a dataset so as to minimize some metric -for overlap, while maximizing the amount of data kept. - -- We solve this problem using the NSGA2 algorithm and save all the results - to disk to be loaded with `get_closest_mask` from `util_load_adversarial_mask.py` -""" - -import gzip -import logging -import os -import pickle -import random -import warnings -from datetime import datetime -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple - -import numpy as np -import ray -import ruck -from matplotlib import pyplot as plt -from ruck import R -from ruck.external.ray import ray_map -from ruck.external.ray import ray_remote_put -from ruck.external.ray import ray_remote_puts -from ruck.external.deap import select_nsga2 - -import research.util as H -from disent.dataset.wrapper import MaskedDataset -from disent.util.function import wrapped_partial -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.profiling import Timer -from disent.util.seeds import seed -from disent.util.visualize.vis_util import get_idx_traversal -from research.e01_visual_overlap.util_compute_traversal_dists import cached_compute_all_factor_dist_matrices -from research.e06_adversarial_data.util_eval_adversarial import eval_individual - - -log = logging.getLogger(__name__) - - -''' -NOTES ON MULTI-OBJECTIVE OPTIMIZATION: - https://en.wikipedia.org/wiki/Pareto_efficiency - https://en.wikipedia.org/wiki/Multi-objective_optimization - https://www.youtube.com/watch?v=SL-u_7hIqjA - - IDEAL MULTI-OBJECTIVE OPTIMIZATION - 1. generate set of pareto-optimal solutions (solutions lying along optimal boundary) -- (A posteriori methods) - - converge to pareto optimal front - - maintain as diverse a population as possible along the front (nsga-ii?) - 2. choose one from set using higher level information - - NOTE: - most multi-objective problems are just - converted into single objective functions. - - WEIGHTED SUMS - -- need to know weights - -- non-uniform in pareto-optimal solutions - -- cannot find some pareto-optimal solutions in non-convex regions - `return w0 * score0 + w1 * score1 + ...` - - ε-CONSTRAINT: constrain all but one objective - -- need to know ε vectors - -- non-uniform in pareto-optimal solutions - -- any pareto-optimal solution can be found - * EMO is a generalisation? -''' - - -# ========================================================================= # -# Ruck Helper # -# ========================================================================= # - - -def mutate_oneof(*mutate_fns): - # TODO: move this into ruck - def mutate_fn(value): - fn = random.choice(mutate_fns) - return fn(value) - return mutate_fn - - -def plt_pareto_solutions( - population, - label_fitness_0: str, - label_fitness_1: str, - title: str = None, - plot: bool = True, - chosen_idxs_f0=None, - chosen_idxs_f1=None, - random_points=None, - **fig_kw, -): - # fitness values must be of type Tuple[float, float] for this function to work!
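(An aside on the multi-objective notes above: a toy sketch of the weighted-sum scalarization they caution against, with arbitrary illustrative weights. This is not what this file does; it uses NSGA2 instead.)

```python
def weighted_sum_scalarize(overlap_score: float, usage_ratio: float,
                           w_overlap: float = 1.0, w_usage: float = 1.0) -> float:
    # reduce two objectives to one: minimize overlap, maximize usage.
    # as noted above, the weights must be chosen up front, and solutions in
    # non-convex regions of the pareto front can never be found this way.
    return w_overlap * overlap_score - w_usage * usage_ratio

# lower is better under this scalarization:
assert weighted_sum_scalarize(0.2, 0.9) < weighted_sum_scalarize(0.8, 0.1)
```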
- fig, axs = H.plt_subplots(1, 1, title=title if title else 'Pareto-Optimal Solutions', **fig_kw) - # plot fitness values - xs, ys = zip(*(m.fitness for m in population)) - axs[0, 0].set_xlabel(label_fitness_0) - axs[0, 0].set_ylabel(label_fitness_1) - # plot random - if random_points is not None: - axs[0, 0].scatter(*np.array(random_points).T, c='orange') - # plot normal - axs[0, 0].scatter(xs, ys) - # plot chosen - if chosen_idxs_f0 is not None: - axs[0, 0].scatter(*np.array([population[i].fitness for i in chosen_idxs_f0]).T, c='purple') - if chosen_idxs_f1 is not None: - axs[0, 0].scatter(*np.array([population[i].fitness for i in chosen_idxs_f1]).T, c='green') - # label axes - # layout - fig.tight_layout() - # plot - if plot: - plt.show() - # done! - return fig, axs - - -def individual_ave(dataset, individual, print_=False): - if isinstance(dataset, str): - dataset = H.make_data(dataset, transform_mode='none') - # masked - sub_data = MaskedDataset(data=dataset, mask=individual.flatten()) - if print_: - print(', '.join(f'{individual.reshape(sub_data._data.factor_sizes).sum(axis=f_idx).mean():2f}' for f_idx in range(sub_data._data.num_factors))) - # make obs - ave_obs = np.zeros_like(sub_data[0], dtype='float64') - for obs in sub_data: - ave_obs += obs - return ave_obs / ave_obs.max() - - -def plot_averages(dataset_name: str, values: list, subtitle: str, title_prefix: str = None, titles=None, show: bool = False): - data = H.make_data(dataset_name, transform_mode='none') - # average individuals - ave_imgs = [individual_ave(data, v) for v in values] - col_lbls = [f'{np.sum(v)} / {np.prod(v.shape)}' for v in values] - # make plots - fig_ave_imgs, _ = H.plt_subplots_imshow( - [ave_imgs], - col_labels=col_lbls, - titles=titles, - show=show, - vmin=0.0, - vmax=1.0, - figsize=(10, 3), - title=f'{f"{title_prefix} " if title_prefix else ""}Average Datasets\n{subtitle}', - ) - return fig_ave_imgs - - -def get_spaced(array, num: int): - return [array[i] for i in get_idx_traversal(len(array), num)] - - -# ========================================================================= # -# Evaluation # -# ========================================================================= # - - -@ray.remote -def evaluate_member( - value: np.ndarray, - gt_dist_matrices: np.ndarray, - factor_sizes: Tuple[int, ...], - fitness_overlap_mode: str, - fitness_overlap_aggregate: str, - fitness_overlap_include_singles: bool, -) -> Tuple[float, float]: - overlap_score, usage_ratio = eval_individual( - individual=value, - gt_dist_matrices=gt_dist_matrices, - factor_sizes=factor_sizes, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_aggregate=fitness_overlap_aggregate, - exclude_diag=True, - increment_single=fitness_overlap_include_singles, - backend='numba', - ) - - # weight components - # assert fitness_overlap_weight >= 0 - # assert fitness_usage_weight >= 0 - # w_ovrlp = fitness_overlap_weight * overlap_score - # w_usage = fitness_usage_weight * usage_ratio - - # GOALS: minimize overlap, maximize usage - # [min, max] objective -> target - # [ 0, 1] factor_score -> 0 - # [ 0, 1] kept_ratio -> 1 - - # linear scalarization - # loss = w_ovrlp - w_usage - - # No-preference method - # -- norm(f(x) - z_ideal) - # -- preferably scale variables - # z_ovrlp = fitness_overlap_weight * (overlap_score - 0.0) - # z_usage = fitness_usage_weight * (usage_ratio - 1.0) - # loss = np.linalg.norm([z_ovrlp, z_usage], ord=2) - - # convert minimization problem into maximization - # return - loss - - if overlap_score < 0: - 
log.warning(f'member has invalid overlap_score: {repr(overlap_score)}') - overlap_score = 1000 # minimizing target to 0 in range [0, 1] so this is bad - if usage_ratio < 0: - log.warning(f'member has invalid usage_ratio: {repr(usage_ratio)}') - usage_ratio = -1000 # maximizing target to 1 in range [0, 1] so this is bad - - return (-overlap_score, usage_ratio) - - -# ========================================================================= # -# Type Hints # -# ========================================================================= # - - -Values = List[ray.ObjectRef] -Population = List[ruck.Member[ray.ObjectRef]] - - -# ========================================================================= # -# Evolutionary System # -# ========================================================================= # - - -class DatasetMaskModule(ruck.EaModule): - - # STATISTICS - - def get_stats_groups(self): - remote_sum = ray.remote(np.mean).remote - return { - **super().get_stats_groups(), - 'mask': ruck.StatsGroup(lambda pop: ray.get([remote_sum(m.value) for m in pop]), min=np.min, max=np.max, mean=np.mean), - } - - def get_progress_stats(self): - return ('evals', 'fit:mean', 'mask:mean') - - # POPULATION - - def gen_starting_values(self) -> Values: - return [ - ray.put(np.random.random(np.prod(self.hparams.factor_sizes)) < (0.1 + np.random.random() * 0.8)) - for _ in range(self.hparams.population_size) - ] - - def select_population(self, population: Population, offspring: Population) -> Population: - return select_nsga2(population + offspring, len(population), weights=(1.0, 1.0)) - - def evaluate_values(self, values: Values) -> List[float]: - return ray.get([self._evaluate_value_fn(v) for v in values]) - - # INITIALISE - - def __init__( - self, - dataset_name: str = 'cars3d', - dist_normalize_mode: str = 'all', - population_size: int = 128, - # fitness settings - fitness_overlap_aggregate: str = 'mean', - fitness_overlap_mode: str = 'std', - fitness_overlap_include_singles: bool = True, - # ea settings - p_mate: float = 0.5, - p_mutate: float = 0.5, - p_mutate_flip: float = 0.05, - ): - # load the dataset - gt_data = H.make_data(dataset_name) - factor_sizes = gt_data.factor_sizes - # save hyper parameters to .hparams - self.save_hyperparameters(include=['factor_sizes']) - # compute all distances - gt_dist_matrices = cached_compute_all_factor_dist_matrices(dataset_name, normalize_mode=dist_normalize_mode) - gt_dist_matrices = ray.put(gt_dist_matrices) - # get offspring function - self.generate_offspring = wrapped_partial( - R.apply_mate_and_mutate, - mate_fn=ray_remote_puts(R.mate_crossover_nd).remote, - mutate_fn=ray_remote_put(mutate_oneof( - wrapped_partial(R.mutate_flip_bits, p=p_mutate_flip), - wrapped_partial(R.mutate_flip_bit_groups, p=p_mutate_flip), - )).remote, - p_mate=p_mate, - p_mutate=p_mutate, - map_fn=ray_map # parallelize - ) - # get evaluation function - self._evaluate_value_fn = wrapped_partial( - evaluate_member.remote, - gt_dist_matrices=gt_dist_matrices, - factor_sizes=factor_sizes, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_aggregate=fitness_overlap_aggregate, - fitness_overlap_include_singles=fitness_overlap_include_singles, - ) - - -# ========================================================================= # -# RUNNER # -# ========================================================================= # - - -def run( - dataset_name: str = 'shapes3d', # xysquares_8x8_toy_s4, xcolumns_8x_toy_s1 - dist_normalize_mode: str = 'all', # all, each, none - # population - 
generations: int = 250, - population_size: int = 128, - # fitness settings - fitness_overlap_mode: str = 'std', - fitness_overlap_aggregate: str = 'mean', - fitness_overlap_include_singles: bool = True, - # save settings - save: bool = False, - save_prefix: str = '', - seed_: Optional[int] = None, - # plot settings - plot: bool = False, - # wandb_settings - wandb_enabled: bool = True, - wandb_init: bool = True, - wandb_project: str = 'exp-adversarial-mask', - wandb_user: str = 'n_michlo', - wandb_job_name: str = None, - wandb_tags: Optional[List[str]] = None, - wandb_finish: bool = True, -) -> Dict[str, Any]: - # save the starting time for the save path - time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') - log.info(f'Starting run at time: {time_string}') - - # get hparams - hparams = dict(dataset_name=dataset_name, dist_normalize_mode=dist_normalize_mode, generations=generations, population_size=population_size, fitness_overlap_mode=fitness_overlap_mode, fitness_overlap_aggregate=fitness_overlap_aggregate, fitness_overlap_include_singles=fitness_overlap_include_singles, save=save, save_prefix=save_prefix, seed_=seed_, plot=plot, wandb_enabled=wandb_enabled, wandb_init=wandb_init, wandb_project=wandb_project, wandb_user=wandb_user, wandb_job_name=wandb_job_name) - # name - name = f'{(save_prefix + "_" if save_prefix else "")}{dataset_name}_{generations}x{population_size}_{dist_normalize_mode}_{fitness_overlap_mode}_{fitness_overlap_aggregate}_{fitness_overlap_include_singles}' - log.info(f'- Run name is: {name}') - - # enable wandb - wandb = None - if wandb_enabled: - import wandb - # cleanup from old runs: - if wandb_init: - if wandb_finish: - try: - wandb.finish() - except: - pass - # initialize - wandb.init( - entity=wandb_user, - project=wandb_project, - name=wandb_job_name if (wandb_job_name is not None) else name, - group=None, - tags=wandb_tags, - ) - # track hparams - wandb.config.update({f'adv/{k}': v for k, v in hparams.items()}) - - # This is not completely deterministic with ray - # although the starting population will always be the same! - seed_ = seed_ if (seed_ is not None) else int(np.random.randint(1, 2**31-1)) - seed(seed_) - - # run! 
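One convention worth making explicit before the training loop below: `evaluate_member` returns `(-overlap_score, usage_ratio)` and `select_population` runs NSGA2 with `weights=(1.0, 1.0)`, so both objectives are maximized and the sign flip encodes "minimize overlap". A tiny illustration of the encoding:

```python
# fitness tuples as produced by evaluate_member: (-overlap_score, usage_ratio)
good = (-0.1, 0.9)  # low overlap, high usage
bad = (-0.9, 0.1)   # high overlap, low usage
# under element-wise maximization, `good` pareto-dominates `bad`:
assert all(g > b for g, b in zip(good, bad))
```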
- with Timer('ruck:onemax'): - problem = DatasetMaskModule( - dataset_name=dataset_name, - dist_normalize_mode=dist_normalize_mode, - population_size=population_size, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_aggregate=fitness_overlap_aggregate, - fitness_overlap_include_singles=fitness_overlap_include_singles, - ) - # train - population, logbook, halloffame = ruck.Trainer(generations=generations, progress=True).fit(problem) - # retrieve stats - log.info(f'start population: {logbook[0]}') - log.info(f'end population: {logbook[-1]}') - values = [ray.get(m.value) for m in halloffame] - - # log to wandb as steps - if wandb_enabled: - for i, stats in enumerate(logbook): - stats = {f'stats/{k}': v for k, v in stats.items()} - stats['current_step'] = i - wandb.log(stats, step=i) - - # generate average images - if plot or wandb_enabled: - # plot average - fig_ave_imgs_hof = plot_averages(dataset_name, values, title_prefix='HOF', subtitle=name, show=plot) - # get individuals -- this is not ideal because not evenly spaced - idxs_chosen_f0 = get_spaced(np.argsort([m.fitness[0] for m in population])[::-1], 5) # overlap - idxs_chosen_f1 = get_spaced(np.argsort([m.fitness[1] for m in population]), 5) # usage - chosen_values_f0 = [ray.get(population[i].value) for i in idxs_chosen_f0] - chosen_values_f1 = [ray.get(population[i].value) for i in idxs_chosen_f1] - random_fitnesses = problem.evaluate_values([ray.put(np.random.random(values[0].shape) < p) for p in np.linspace(0.025, 1, num=population_size+2)[1:-1]]) - # plot averages - fig_ave_imgs_f0 = plot_averages(dataset_name, chosen_values_f0, subtitle=name, titles=[f'{population[i].fitness[0]:2f}' for i in idxs_chosen_f0], title_prefix='Overlap -', show=plot) - fig_ave_imgs_f1 = plot_averages(dataset_name, chosen_values_f1, subtitle=name, titles=[f'{population[i].fitness[1]:2f}' for i in idxs_chosen_f1], title_prefix='Usage -', show=plot) - # plot pareto optimal solutions - fig_pareto_sol, axs = plt_pareto_solutions( - population, - label_fitness_0='Overlap Score', - label_fitness_1='Usage Score', - title=f'Pareto-Optimal Solutions\n{name}', - plot=plot, - chosen_idxs_f0=idxs_chosen_f0, - chosen_idxs_f1=idxs_chosen_f1, - random_points=random_fitnesses, - figsize=(7, 7), - ) - # log average - if wandb_enabled: - wandb.log({ - 'ave_images_hof': wandb.Image(fig_ave_imgs_hof), - 'ave_images_overlap': wandb.Image(fig_ave_imgs_f0), - 'ave_images_usage': wandb.Image(fig_ave_imgs_f1), - 'pareto_solutions': wandb.Image(fig_pareto_sol), - }) - - # get summary - use_elems = np.sum(values[0]) - num_elems = np.prod(values[0].shape) - use_ratio = (use_elems / num_elems) - - # log summary - if wandb_enabled: - wandb.summary['num_elements'] = num_elems - wandb.summary['used_elements'] = use_elems - wandb.summary['used_elements_ratio'] = use_ratio - for k, v in logbook[0].items(): wandb.summary[f'log:start:{k}'] = v - for k, v in logbook[-1].items(): wandb.summary[f'log:end:{k}'] = v - - # generate paths - job_name = f'{time_string}_{name}' - - # collect results - results = { - 'hparams': hparams, - 'job_name': job_name, - 'save_path': None, - 'time_string': time_string, - 'values': [ray.get(m.value) for m in population], - 'scores': [m.fitness for m in population], - # score components - 'scores_overlap': [m.fitness[0] for m in population], - 'scores_usage': [m.fitness[1] for m in population], - # history data - 'logbook_history': logbook.history, - # we don't want these because they store object refs, and - # it means we need ray to unpickle
them. - # 'population': population, - # 'halloffame_members': halloffame.members, - } - - if save: - # get save path, make parent dir & save! - results['save_path'] = ensure_parent_dir_exists(ROOT_DIR, 'out/adversarial_mask', job_name, 'data.pkl.gz') - # NONE : 122943493 ~= 118M (100.%) : 103.420ms - # lvl=1 : 23566691 ~= 23M (19.1%) : 1.223s - # lvl=2 : 21913595 ~= 21M (17.8%) : 1.463s - # lvl=3 : 20688319 ~= 20M (16.8%) : 2.504s - # lvl=4 : 18325859 ~= 18M (14.9%) : 1.856s # good - # lvl=5 : 17467772 ~= 17M (14.2%) : 3.332s # good - # lvl=6 : 16594660 ~= 16M (13.5%) : 7.163s # starting to slow - # lvl=7 : 16242279 ~= 16M (13.2%) : 12.407s - # lvl=8 : 15586416 ~= 15M (12.7%) : 1m:4s # far too slow - # lvl=9 : 15023324 ~= 15M (12.2%) : 3m:11s # far too slow - log.info(f'saving data to: {results["save_path"]}') - with gzip.open(results["save_path"], 'wb', compresslevel=5) as fp: - pickle.dump(results, fp) - log.info(f'saved data to: {results["save_path"]}') - - # cleanup wandb - if wandb_enabled: - if wandb_finish: - try: - wandb.finish() - except: - pass - - # done - return results - - -# ========================================================================= # -# ENTRYPOINT # -# ========================================================================= # - - -ROOT_DIR = os.path.abspath(__file__ + '/../../..') - - -def main(): - from itertools import product - - # (2 * 3 * 2 * 2 * 5) = 120 - for (fitness_overlap_include_singles, dist_normalize_mode, fitness_overlap_aggregate, fitness_overlap_mode, dataset_name) in product( - [True, False], - ['all', 'each', 'none'], - ['gmean', 'mean'], - ['std', 'range'], - ['xysquares_8x8_toy_s2', 'cars3d', 'smallnorb', 'shapes3d', 'dsprites'], - ): - print('='*100) - print(f'[STARTING]: dataset_name={repr(dataset_name)} dist_normalize_mode={repr(dist_normalize_mode)} fitness_overlap_mode={repr(fitness_overlap_mode)} fitness_overlap_aggregate={repr(fitness_overlap_aggregate)} fitness_overlap_include_singles={repr(fitness_overlap_include_singles)}') - try: - run( - dataset_name=dataset_name, - dist_normalize_mode=dist_normalize_mode, - # fitness - fitness_overlap_aggregate=fitness_overlap_aggregate, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_include_singles=fitness_overlap_include_singles, - # population - generations=1000, - population_size=256, - seed_=42, - save=True, - save_prefix='EXP', - plot=True, - wandb_enabled=True, - wandb_project='exp-adversarial-mask', - wandb_tags=['exp_factor_dists'] - ) - except KeyboardInterrupt: - warnings.warn('Exiting early') - exit(1) - # except: - # warnings.warn(f'[FAILED]: dataset_name={repr(dataset_name)} dist_normalize_mode={repr(dist_normalize_mode)} fitness_overlap_mode={repr(fitness_overlap_mode)} fitness_overlap_aggregate={repr(fitness_overlap_aggregate)}') - print('='*100) - - -if __name__ == '__main__': - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # run - logging.basicConfig(level=logging.INFO) - ray.init(num_cpus=64) - main() - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck_dist_pairs.py b/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck_dist_pairs.py deleted file mode 100644 index 9aee9e08..00000000 --- a/research/e06_adversarial_data/deprecated/run_04_gen_adversarial_ruck_dist_pairs.py +++ /dev/null @@ -1,601 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -This file generates pareto-optimal solutions to the multi-objective -optimisation problem of masking a dataset so as to minimize some metric -for overlap, while maximizing the amount of data kept. - -- We solve this problem using the NSGA2 algorithm and save all the results - to disk to be loaded with `get_closest_mask` from `util_load_adversarial_mask.py` -""" - -import gzip -import logging -import os -import pickle -import random -import warnings -from datetime import datetime -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple - -import numpy as np -import psutil -import ray -import ruck -from matplotlib import pyplot as plt -from ruck import R -from ruck.external.ray import ray_map -from ruck.external.ray import ray_remote_put -from ruck.external.ray import ray_remote_puts - -from ruck.external.deap import select_nsga2 as select_nsga2_deap -# from ruck.functional import select_nsga2 as select_nsga2_ruck # should rather use this! - -import research.util as H -from disent.dataset.wrapper import MaskedDataset -from disent.util.function import wrapped_partial -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.profiling import Timer -from disent.util.seeds import seed -from disent.util.visualize.vis_util import get_idx_traversal -from research.e01_visual_overlap.util_compute_traversal_dist_pairs import cached_compute_dataset_pair_dists -from research.e06_adversarial_data.util_eval_adversarial_dist_pairs import eval_masked_dist_pairs - - -log = logging.getLogger(__name__) - - -''' -NOTES ON MULTI-OBJECTIVE OPTIMIZATION: - https://en.wikipedia.org/wiki/Pareto_efficiency - https://en.wikipedia.org/wiki/Multi-objective_optimization - https://www.youtube.com/watch?v=SL-u_7hIqjA - - IDEAL MULTI-OBJECTIVE OPTIMIZATION - 1. generate set of pareto-optimal solutions (solutions lying along optimal boundary) -- (A posteriori methods) - - converge to pareto optimal front - - maintain as diverse a population as possible along the front (nsga-ii?) - 2. choose one from set using higher level information - - NOTE: - most multi-objective problems are just - converted into single objective functions.
- - WEIGHTED SUMS - -- need to know weights - -- non-uniform in pareto-optimal solutions - -- cannot find some pareto-optimal solutions in non-convex regions - `return w0 * score0 + w1 * score1 + ...` - - ε-CONSTRAINT: constrain all but one objective - -- need to know ε vectors - -- non-uniform in pareto-optimal solutions - -- any pareto-optimal solution can be found - * EMO is a generalisation? -''' - - -# ========================================================================= # -# Ruck Helper # -# ========================================================================= # - - -def mutate_oneof(*mutate_fns): - # TODO: move this into ruck - def mutate_fn(value): - fn = random.choice(mutate_fns) - return fn(value) - return mutate_fn - - -def plt_pareto_solutions( - population, - label_fitness_0: str, - label_fitness_1: str, - title: str = None, - plot: bool = True, - chosen_idxs_f0=None, - chosen_idxs_f1=None, - random_points=None, - **fig_kw, -): - # fitness values must be of type Tuple[float, float] for this function to work! - fig, axs = H.plt_subplots(1, 1, title=title if title else 'Pareto-Optimal Solutions', **fig_kw) - # plot fitness values - xs, ys = zip(*(m.fitness for m in population)) - axs[0, 0].set_xlabel(label_fitness_0) - axs[0, 0].set_ylabel(label_fitness_1) - # plot random - if random_points is not None: - axs[0, 0].scatter(*np.array(random_points).T, c='orange') - # plot normal - axs[0, 0].scatter(xs, ys) - # plot chosen - if chosen_idxs_f0 is not None: - axs[0, 0].scatter(*np.array([population[i].fitness for i in chosen_idxs_f0]).T, c='purple') - if chosen_idxs_f1 is not None: - axs[0, 0].scatter(*np.array([population[i].fitness for i in chosen_idxs_f1]).T, c='green') - # label axes - # layout - fig.tight_layout() - # plot - if plot: - plt.show() - # done! 
- return fig, axs - - -def individual_ave(dataset, individual, print_=False): - if isinstance(dataset, str): - dataset = H.make_data(dataset, transform_mode='none') - # masked - sub_data = MaskedDataset(data=dataset, mask=individual.flatten()) - if print_: - print(', '.join(f'{individual.reshape(sub_data._data.factor_sizes).sum(axis=f_idx).mean():2f}' for f_idx in range(sub_data._data.num_factors))) - # make obs - ave_obs = np.zeros_like(sub_data[0], dtype='float64') - for obs in sub_data: - ave_obs += obs - return ave_obs / ave_obs.max() - - -def plot_averages(dataset_name: str, values: list, subtitle: str, title_prefix: str = None, titles=None, show: bool = False): - data = H.make_data(dataset_name, transform_mode='none') - # average individuals - ave_imgs = [individual_ave(data, v) for v in values] - col_lbls = [f'{np.sum(v)} / {np.prod(v.shape)}' for v in values] - # make plots - fig_ave_imgs, _ = H.plt_subplots_imshow( - [ave_imgs], - col_labels=col_lbls, - titles=titles, - show=show, - vmin=0.0, - vmax=1.0, - figsize=(10, 3), - title=f'{f"{title_prefix} " if title_prefix else ""}Average Datasets\n{subtitle}', - ) - return fig_ave_imgs - - -def get_spaced(array, num: int): - return [array[i] for i in get_idx_traversal(len(array), num)] - - -# ========================================================================= # -# Evaluation # -# ========================================================================= # - - -@ray.remote -def evaluate_member( - value: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - fitness_overlap_mode: str, - fitness_overlap_include_singles: bool = True, -) -> Tuple[float, float]: - overlap_score, usage_ratio = eval_masked_dist_pairs( - mask=value, - pair_obs_dists=pair_obs_dists, - pair_obs_idxs=pair_obs_idxs, - fitness_mode=fitness_overlap_mode, - increment_single=fitness_overlap_include_singles, - backend='numba', - ) - - # weight components - # assert fitness_overlap_weight >= 0 - # assert fitness_usage_weight >= 0 - # w_ovrlp = fitness_overlap_weight * overlap_score - # w_usage = fitness_usage_weight * usage_ratio - - # GOALS: minimize overlap, maximize usage - # [min, max] objective -> target - # [ 0, 1] factor_score -> 0 - # [ 0, 1] kept_ratio -> 1 - - # linear scalarization - # loss = w_ovrlp - w_usage - - # No-preference method - # -- norm(f(x) - z_ideal) - # -- preferably scale variables - # z_ovrlp = fitness_overlap_weight * (overlap_score - 0.0) - # z_usage = fitness_usage_weight * (usage_ratio - 1.0) - # loss = np.linalg.norm([z_ovrlp, z_usage], ord=2) - - # convert minimization problem into maximization - # return - loss - - if overlap_score < 0: - log.warning(f'member has invalid overlap_score: {repr(overlap_score)}') - overlap_score = 1000 # minimizing target to 0 in range [0, 1] so this is bad - if usage_ratio < 0: - log.warning(f'member has invalid usage_ratio: {repr(usage_ratio)}') - usage_ratio = -1000 # maximizing target to 1 in range [0, 1] so this is bad - - return (-overlap_score, usage_ratio) - - -# ========================================================================= # -# Type Hints # -# ========================================================================= # - - -Values = List[ray.ObjectRef] -Population = List[ruck.Member[ray.ObjectRef]] - - -# ========================================================================= # -# Evolutionary System # -# ========================================================================= # - - -class DatasetDistPairMaskModule(ruck.EaModule): - - # STATISTICS - - def 
get_stats_groups(self): - remote_sum = ray.remote(np.mean).remote - return { - **super().get_stats_groups(), - 'mask': ruck.StatsGroup(lambda pop: ray.get([remote_sum(m.value) for m in pop]), min=np.min, max=np.max, mean=np.mean), - } - - def get_progress_stats(self): - return ('evals', 'fit:mean', 'mask:mean') - - # POPULATION - - def gen_starting_values(self) -> Values: - return [ - ray.put(np.random.random(np.prod(self.hparams.factor_sizes)) < (0.1 + np.random.random() * 0.8)) - for _ in range(self.hparams.population_size) - ] - - def select_population(self, population: Population, offspring: Population) -> Population: - return select_nsga2_deap(population + offspring, len(population)) - - def evaluate_values(self, values: Values) -> List[float]: - return ray.get([self._evaluate_value_fn(v) for v in values]) - - # INITIALISE - - def __init__( - self, - dataset_name: str = 'smallnorb', - pair_mode: str = 'nearby_scaled', # random, nearby, nearby_scaled - pairs_per_obs: int = 100, - pairs_seed: Optional[int] = None, - dists_scaled: bool = True, - # population - population_size: int = 128, - # fitness settings - fitness_overlap_mode: str = 'std', - fitness_overlap_include_singles: bool = True, - # ea settings - p_mate: float = 0.5, - p_mutate: float = 0.5, - p_mutate_flip: float = 0.05, - ): - # load the dataset - gt_data = H.make_data(dataset_name) - factor_sizes = gt_data.factor_sizes - # save hyper parameters to .hparams - self.save_hyperparameters(include=['factor_sizes']) - # compute all distances - obs_pair_idxs, obs_pair_dists = cached_compute_dataset_pair_dists(dataset_name, pair_mode=pair_mode, pairs_per_obs=pairs_per_obs, seed=pairs_seed, scaled=dists_scaled) - obs_pair_idxs = ray.put(obs_pair_idxs) - obs_pair_dists = ray.put(obs_pair_dists) - # get offspring function - self.generate_offspring = wrapped_partial( - R.apply_mate_and_mutate, - mate_fn=ray_remote_puts(R.mate_crossover_nd).remote, - mutate_fn=ray_remote_put(mutate_oneof( - wrapped_partial(R.mutate_flip_bits, p=p_mutate_flip), - wrapped_partial(R.mutate_flip_bit_groups, p=p_mutate_flip), - )).remote, - p_mate=p_mate, - p_mutate=p_mutate, - map_fn=ray_map # parallelize - ) - # get evaluation function - self._evaluate_value_fn = wrapped_partial( - evaluate_member.remote, - pair_obs_dists=obs_pair_dists, - pair_obs_idxs=obs_pair_idxs, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_include_singles=fitness_overlap_include_singles, - ) - - -# ========================================================================= # -# RUNNER # -# ========================================================================= # - - -def run( - dataset_name: str = 'shapes3d', # xysquares_8x8_toy_s4, xcolumns_8x_toy_s1 - pair_mode: str = 'nearby_scaled', - pairs_per_obs: int = 64, - dists_scaled: bool = True, - # population - generations: int = 250, - population_size: int = 128, - # fitness settings - fitness_overlap_mode: str = 'std', - fitness_overlap_include_singles: bool = True, - # save settings - save: bool = False, - save_prefix: str = '', - seed_: Optional[int] = None, - # plot settings - plot: bool = False, - # wandb_settings - wandb_enabled: bool = True, - wandb_init: bool = True, - wandb_project: str = 'exp-adversarial-mask', - wandb_user: str = 'n_michlo', - wandb_job_name: str = None, - wandb_tags: Optional[List[str]] = None, - wandb_finish: bool = True, -) -> Dict[str, Any]: - # save the starting time for the save path - time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') - log.info(f'Starting run at time: 
{time_string}') - - # get hparams - hparams = dict(dataset_name=dataset_name, pair_mode=pair_mode, pairs_per_obs=pairs_per_obs, dists_scaled=dists_scaled, generations=generations, population_size=population_size, fitness_overlap_mode=fitness_overlap_mode, fitness_overlap_include_singles=fitness_overlap_include_singles, save=save, save_prefix=save_prefix, seed_=seed_, plot=plot, wandb_enabled=wandb_enabled, wandb_init=wandb_init, wandb_project=wandb_project, wandb_user=wandb_user, wandb_job_name=wandb_job_name, wandb_tags=wandb_tags, wandb_finish=wandb_finish) - # name - name = f'{(save_prefix + "_" if save_prefix else "")}{dataset_name}_{generations}x{population_size}_{pair_mode}_{pairs_per_obs}_{dists_scaled}_{fitness_overlap_mode}_{fitness_overlap_include_singles}' - log.info(f'- Run name is: {name}') - - # enable wandb - wandb = None - if wandb_enabled: - import wandb - # cleanup from old runs: - if wandb_init: - if wandb_finish: - try: - wandb.finish() - except: - pass - # initialize - wandb.init( - entity=wandb_user, - project=wandb_project, - name=wandb_job_name if (wandb_job_name is not None) else name, - group=None, - tags=wandb_tags, - ) - # track hparams - wandb.config.update({f'adv/{k}': v for k, v in hparams.items()}) - - # This is not completely deterministic with ray - # although the starting population will always be the same! - seed_ = seed_ if (seed_ is not None) else int(np.random.randint(1, 2**31-1)) - seed(seed_) - - # run! - with Timer('ruck:onemax'): - problem = DatasetDistPairMaskModule( - dataset_name=dataset_name, - pair_mode=pair_mode, - pairs_per_obs=pairs_per_obs, - # pairs_seed=pairs_seed, - dists_scaled=dists_scaled, - # population - population_size=population_size, - # fitness settings - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_include_singles=fitness_overlap_include_singles, - # ea settings - # p_mate=p_mate, - # p_mutate=p_mutate, - # p_mutate_flip=p_mutate_flip, - ) - # train - population, logbook, halloffame = ruck.Trainer(generations=generations, progress=True).fit(problem) - # retrieve stats - log.info(f'start population: {logbook[0]}') - log.info(f'end population: {logbook[-1]}') - values = [ray.get(m.value) for m in halloffame] - - # log to wandb as steps - if wandb_enabled: - for i, stats in enumerate(logbook): - stats = {f'stats/{k}': v for k, v in stats.items()} - stats['current_step'] = i - wandb.log(stats, step=i) - - # generate average images - if plot or wandb_enabled: - # plot average - fig_ave_imgs_hof = plot_averages(dataset_name, values, title_prefix='HOF', subtitle=name, show=plot) - # get individuals -- this is not ideal because not evenly spaced - idxs_chosen_f0 = get_spaced(np.argsort([m.fitness[0] for m in population])[::-1], 5) # overlap - idxs_chosen_f1 = get_spaced(np.argsort([m.fitness[1] for m in population]), 5) # usage - chosen_values_f0 = [ray.get(population[i].value) for i in idxs_chosen_f0] - chosen_values_f1 = [ray.get(population[i].value) for i in idxs_chosen_f1] - random_fitnesses = problem.evaluate_values([ray.put(np.random.random(values[0].shape) < p) for p in np.linspace(0.025, 1, num=population_size+2)[1:-1]]) - # plot averages - fig_ave_imgs_f0 = plot_averages(dataset_name, chosen_values_f0, subtitle=name, titles=[f'{population[i].fitness[0]:2f}' for i in idxs_chosen_f0], title_prefix='Overlap -', show=plot) - fig_ave_imgs_f1 = plot_averages(dataset_name, chosen_values_f1, subtitle=name, titles=[f'{population[i].fitness[1]:2f}' for i in idxs_chosen_f1], title_prefix='Usage -', show=plot) - # plot 
pareto optimal solutions - fig_pareto_sol, axs = plt_pareto_solutions( - population, - label_fitness_0='Overlap Score', - label_fitness_1='Usage Score', - title=f'Pareto-Optimal Solutions\n{name}', - plot=plot, - chosen_idxs_f0=idxs_chosen_f0, - chosen_idxs_f1=idxs_chosen_f1, - random_points=random_fitnesses, - figsize=(7, 7), - ) - # plot factor usage ratios - # TODO: PLOT 2D matrix of all permutations of factors aggregated - # log average - if wandb_enabled: - wandb.log({ - 'ave_images_hof': wandb.Image(fig_ave_imgs_hof), - 'ave_images_overlap': wandb.Image(fig_ave_imgs_f0), - 'ave_images_usage': wandb.Image(fig_ave_imgs_f1), - 'pareto_solutions': wandb.Image(fig_pareto_sol), - }) - - # get summary - use_elems = np.sum(values[0]) - num_elems = np.prod(values[0].shape) - use_ratio = (use_elems / num_elems) - - # log summary - if wandb_enabled: - wandb.summary['num_elements'] = num_elems - wandb.summary['used_elements'] = use_elems - wandb.summary['used_elements_ratio'] = use_ratio - for k, v in logbook[0].items(): wandb.summary[f'log:start:{k}'] = v - for k, v in logbook[-1].items(): wandb.summary[f'log:end:{k}'] = v - - # generate paths - job_name = f'{time_string}_{name}' - - # collect results - results = { - 'hparams': hparams, - 'job_name': job_name, - 'save_path': None, - 'time_string': time_string, - 'values': [ray.get(m.value) for m in population], - 'scores': [m.fitness for m in population], - # score components - 'scores_overlap': [m.fitness[0] for m in population], - 'scores_usage': [m.fitness[1] for m in population], - # history data - 'logbook_history': logbook.history, - # we don't want these because they store object refs, and - # it means we need ray to unpickle them. - # 'population': population, - # 'halloffame_members': halloffame.members, - } - - if save: - # get save path, make parent dir & save!
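A rough sketch of reading these gzipped results back and choosing a mask by target usage ratio -- a simplified stand-in for the `get_closest_mask` helper named in the module docstring, not its actual implementation:

```python
import gzip
import pickle
import numpy as np

def load_mask_closest_to_usage(path: str, usage_ratio: float = 0.5) -> np.ndarray:
    # load the `results` dict saved below
    with gzip.open(path, 'rb') as fp:
        results = pickle.load(fp)
    # pick the pareto member whose usage score is closest to the target
    idx = int(np.argmin(np.abs(np.array(results['scores_usage']) - usage_ratio)))
    return results['values'][idx]
```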
- results['save_path'] = ensure_parent_dir_exists(ROOT_DIR, 'out/adversarial_mask', job_name, 'data.pkl.gz') - # NONE : 122943493 ~= 118M (100.%) : 103.420ms - # lvl=1 : 23566691 ~= 23M (19.1%) : 1.223s - # lvl=2 : 21913595 ~= 21M (17.8%) : 1.463s - # lvl=3 : 20688319 ~= 20M (16.8%) : 2.504s - # lvl=4 : 18325859 ~= 18M (14.9%) : 1.856s # good - # lvl=5 : 17467772 ~= 17M (14.2%) : 3.332s # good - # lvl=6 : 16594660 ~= 16M (13.5%) : 7.163s # starting to slow - # lvl=7 : 16242279 ~= 16M (13.2%) : 12.407s - # lvl=8 : 15586416 ~= 15M (12.7%) : 1m:4s # far too slow - # lvl=9 : 15023324 ~= 15M (12.2%) : 3m:11s # far too slow - log.info(f'saving data to: {results["save_path"]}') - with gzip.open(results["save_path"], 'wb', compresslevel=5) as fp: - pickle.dump(results, fp) - log.info(f'saved data to: {results["save_path"]}') - - # cleanup wandb - if wandb_enabled: - if wandb_finish: - try: - wandb.finish() - except: - pass - - # done - return results - - -# ========================================================================= # -# ENTRYPOINT # -# ========================================================================= # - - -ROOT_DIR = os.path.abspath(__file__ + '/../../..') - - -def main(): - from itertools import product - - # (2*1 * 3*1*2 * 5) = 60 - for i, (fitness_overlap_include_singles, dists_scaled, pair_mode, pairs_per_obs, fitness_overlap_mode, dataset_name) in enumerate(product( - [True, False], - [True], # [True, False] - ['nearby_scaled', 'nearby', 'random'], - [256], # [64, 16, 256] - ['std', 'range'], - ['xysquares_8x8_toy_s2', 'cars3d', 'smallnorb', 'shapes3d', 'dsprites'], # ['xysquares_8x8_toy_s2'] - )): - print('='*100) - print(f'[STARTING]: i={i} dataset_name={repr(dataset_name)} pair_mode={repr(pair_mode)} pairs_per_obs={repr(pairs_per_obs)} dists_scaled={repr(dists_scaled)} fitness_overlap_mode={repr(fitness_overlap_mode)} fitness_overlap_include_singles={repr(fitness_overlap_include_singles)}') - try: - run( - dataset_name=dataset_name, - pair_mode=pair_mode, - pairs_per_obs=pairs_per_obs, - dists_scaled=dists_scaled, - fitness_overlap_mode=fitness_overlap_mode, - fitness_overlap_include_singles=fitness_overlap_include_singles, - # population - generations=1000, # 1000 - population_size=384, - seed_=42, - save=True, - save_prefix='DISTS-SCALED', - plot=True, - wandb_enabled=True, - wandb_project='exp-adversarial-mask', - wandb_tags=['exp_pair_dists'] - ) - except KeyboardInterrupt: - warnings.warn('Exiting early') - exit(1) - except: - warnings.warn(f'[FAILED] i={i}') - print('='*100) - - -if __name__ == '__main__': - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # run - logging.basicConfig(level=logging.INFO) - ray.init(num_cpus=psutil.cpu_count(logical=False)) - main() - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e06_adversarial_data/deprecated/submit_02_train_adversarial_data.sh b/research/e06_adversarial_data/deprecated/submit_02_train_adversarial_data.sh deleted file mode 100644 index ef6cf701..00000000 --- a/research/e06_adversarial_data/deprecated/submit_02_train_adversarial_data.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export 
PROJECT="final-06__adversarial-modified-data" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(dirname "$(realpath -s "$0")")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# 1 * (4 * 2 * 2) = 16 -local_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_griffin' \ - run_location='griffin' \ - \ - run_length=short \ - metrics=fast \ - \ - framework.beta=0.001,0.00316,0.01,0.000316 \ - framework=betavae,adavae_os \ - model.z_size=25 \ - \ - dataset=X--adv-dsprites--WARNING,X--adv-shapes3d--WARNING \ - sampling=default__bb # \ - # \ - # hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97"' # we don't want to sweep over these - -# 2 * (8 * 2 * 4) = 128 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_beta' \ - \ - run_length=short \ - metrics=fast \ - \ - framework.beta=0.000316,0.001,0.00316,0.01,0.0316,0.1,0.316,1.0 \ - framework=betavae,adavae_os \ - model.z_size=25 \ - \ - dataset=dsprites,shapes3d,cars3d,smallnorb \ - sampling=default__bb \ - \ - hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97,mscluster99"' # we don't want to sweep over these diff --git a/research/e06_adversarial_data/deprecated/submit_04_train_dsprites_imagenet.sh b/research/e06_adversarial_data/deprecated/submit_04_train_dsprites_imagenet.sh deleted file mode 100644 index 5737db33..00000000 --- a/research/e06_adversarial_data/deprecated/submit_04_train_dsprites_imagenet.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-06__dsprites-imagenet" -export PARTITION="stampede" -export PARALLELISM=36 - -# source the helper file -source "$(dirname "$(dirname "$(dirname "$(realpath -s "$0")")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# (3*2*2*11) = 132 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_dsprites_imagenet' \ - \ - run_callbacks=vis \ - run_length=medium \ - metrics=fast \ - \ - model.z_size=9,16 \ - framework.beta=0.0316,0.01,0.1 \ - framework=adavae_os,betavae \ - \ - dataset=dsprites,X--dsprites-imagenet-bg-20,X--dsprites-imagenet-bg-40,X--dsprites-imagenet-bg-60,X--dsprites-imagenet-bg-80,X--dsprites-imagenet-bg-100,X--dsprites-imagenet-fg-20,X--dsprites-imagenet-fg-40,X--dsprites-imagenet-fg-60,X--dsprites-imagenet-fg-80,X--dsprites-imagenet-fg-100 \ - sampling=default__bb \ - \ - hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97,mscluster99"' # we don't want to sweep over these diff --git a/research/e06_adversarial_data/deprecated/submit_04_train_masked_data.sh b/research/e06_adversarial_data/deprecated/submit_04_train_masked_data.sh deleted file mode 100644 index 4a5f57dc..00000000 --- a/research/e06_adversarial_data/deprecated/submit_04_train_masked_data.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -# 
========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-06__masked-datasets" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(dirname "$(realpath -s "$0")")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# 3 * (12 * 2 * 2) = 144 -#submit_sweep \ -# +DUMMY.repeat=1,2,3 \ -# +EXTRA.tags='sweep_01' \ -# \ -# run_length=medium \ -# \ -# framework.beta=0.001 \ -# framework=betavae,adavae_os \ -# model.z_size=9 \ -# \ -# dataset=X--mask-adv-f-dsprites,X--mask-ran-dsprites,dsprites,X--mask-adv-f-shapes3d,X--mask-ran-shapes3d,shapes3d,X--mask-adv-f-smallnorb,X--mask-ran-smallnorb,smallnorb,X--mask-adv-f-cars3d,X--mask-ran-cars3d,cars3d \ -# sampling=random \ -# \ -# hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97"' # we don't want to sweep over these - -# TODO: beta needs to be tuned! -# 3 * (12*3*2 = 72) = 216 -submit_sweep \ - +DUMMY.repeat=1,2,3 \ - +EXTRA.tags='sweep_usage_ratio' \ - \ - run_callbacks=vis \ - run_length=short \ - metrics=all \ - \ - framework.beta=0.001 \ - framework=betavae,adavae_os \ - model.z_size=25 \ - framework.optional.usage_ratio=0.5,0.2,0.05 \ - \ - dataset=X--mask-adv-f-dsprites,X--mask-ran-dsprites,dsprites,X--mask-adv-f-shapes3d,X--mask-ran-shapes3d,shapes3d,X--mask-adv-f-smallnorb,X--mask-ran-smallnorb,smallnorb,X--mask-adv-f-cars3d,X--mask-ran-cars3d,cars3d \ - sampling=random \ - \ - hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97,mscluster99"' # we don't want to sweep over these diff --git a/research/e06_adversarial_data/deprecated/submit_04_train_masked_data_dist_pairs.sh b/research/e06_adversarial_data/deprecated/submit_04_train_masked_data_dist_pairs.sh deleted file mode 100644 index 59d9ad11..00000000 --- a/research/e06_adversarial_data/deprecated/submit_04_train_masked_data_dist_pairs.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-06__masked-datasets-dist-pairs" -export PARTITION="stampede" -export PARALLELISM=36 - -# source the helper file -source "$(dirname "$(dirname "$(dirname "$(realpath -s "$0")")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TODO: update this script -echo UPDATE THIS SCRIPT -exit 1 - -# (3*2*3*12) = 216 -# TODO: z_size needs tuning -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='sweep_dist_pairs_usage_ratio' \ - \ - run_callbacks=vis \ - run_length=short \ - metrics=all \ - \ - framework.beta=0.0316,0.01,0.1 \ - framework=betavae,adavae_os \ - model.z_size=16 \ - framework.optional.usage_ratio=0.5,0.2,0.05 \ - \ -
dataset=X--mask-adv-r-dsprites,X--mask-ran-dsprites,dsprites,X--mask-adv-r-shapes3d,X--mask-ran-shapes3d,shapes3d,X--mask-adv-r-smallnorb,X--mask-ran-smallnorb,smallnorb,X--mask-adv-r-cars3d,X--mask-ran-cars3d,cars3d \ - sampling=random \ - \ - hydra.launcher.exclude='"mscluster93,mscluster94,mscluster97,mscluster99"' # we don't want to sweep over these diff --git a/research/e06_adversarial_data/run_02_adv_dataset_approx.sh b/research/e06_adversarial_data/run_02_adv_dataset_approx.sh deleted file mode 100644 index 75dd8b44..00000000 --- a/research/e06_adversarial_data/run_02_adv_dataset_approx.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -# get the path to the script -PARENT_DIR="$(dirname "$(realpath -s "$0")")" -ROOT_DIR="$(dirname "$(dirname "$PARENT_DIR")")" - -# maybe lower lr or increase batch size? -#PYTHONPATH="$ROOT_DIR" python3 "$PARENT_DIR/run_02_gen_adversarial_dataset_approx.py" \ -# -m \ -# adv_system.sampler_name=close_p_random_n,same_k1_close \ -# adv_system.adversarial_mode=self,invert_margin_0.005 \ -# adv_system.dataset_name=dsprites,shapes3d,cars3d,smallnorb - -#PYTHONPATH="$ROOT_DIR" python3 "$PARENT_DIR/run_02_gen_adversarial_dataset_approx.py" \ -# -m \ -# settings.dataset.batch_size=32,256 \ -# adv_system.loss_out_of_bounds_weight=0.0,1.0 \ -# \ -# adv_system.sampler_name=close_p_random_n \ -# adv_system.adversarial_mode=invert_margin_0.05,invert_margin_0.005,invert_margin_0.0005 \ -# adv_system.dataset_name=smallnorb - -PYTHONPATH="$ROOT_DIR" python3 "$PARENT_DIR/run_02_gen_adversarial_dataset_approx.py" \ - -m "$@" \ - \ - +meta.tag='unbounded_manhat' \ - settings.job.name_prefix=MANHAT \ - \ - adv_system.sampler_name=same_k_close,random_swap_manhattan,close_p_random_n \ - adv_system.samples_sort_mode=swap,sort_reverse,none,sort_inorder \ - \ - adv_system.adversarial_mode=triplet_unbounded \ - adv_system.dataset_name=smallnorb \ - \ - trainer.max_steps=7500 \ - trainer.max_epochs=7500 \ - \ - adv_system.optimizer_lr=5e-3 \ - settings.exp.show_every_n_steps=500 \ - \ - settings.dataset.batch_size=128 diff --git a/research/e06_adversarial_data/run_02_gen_adversarial_dataset_approx.py b/research/e06_adversarial_data/run_02_gen_adversarial_dataset_approx.py deleted file mode 100644 index 0caf72ec..00000000 --- a/research/e06_adversarial_data/run_02_gen_adversarial_dataset_approx.py +++ /dev/null @@ -1,620 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -Generate an adversarial dataset by approximating the difference between -the dataset and the target adversarial images using a model. - adv = obs + diff(obs) -""" - -import logging -import os -from datetime import datetime -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple - -import hydra -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -import wandb -from omegaconf import OmegaConf -from torch.utils.data import DataLoader - -import research.util as H -from disent import registry -from disent.dataset import DisentDataset -from disent.dataset.sampling import BaseDisentSampler -from disent.dataset.util.hdf5 import H5Builder -from disent.model import AutoEncoder -from disent.nn.activations import Swish -from disent.nn.modules import DisentModule -from disent.nn.weights import init_model_weights -from disent.util import to_numpy -from disent.util.function import wrapped_partial -from disent.util.inout.paths import ensure_parent_dir_exists -from disent.util.lightning.callbacks import BaseCallbackPeriodic -from disent.util.lightning.callbacks import LoggerProgressCallback -from disent.util.lightning.logger_util import wb_has_logger -from disent.util.lightning.logger_util import wb_log_metrics -from disent.util.seeds import seed -from disent.util.seeds import TempNumpySeed -from disent.util.strings.fmt import bytes_to_human -from disent.util.strings.fmt import make_box_str -from disent.util.visualize.vis_util import make_image_grid -from experiment.run import hydra_get_gpus -from experiment.run import hydra_get_callbacks -from experiment.run import hydra_make_logger -from experiment.util.hydra_utils import make_non_strict -from experiment.util.run_utils import log_error_and_exit -from research.e06_adversarial_data.util_gen_adversarial_dataset import adversarial_loss -from research.e06_adversarial_data.util_gen_adversarial_dataset import make_adversarial_sampler -from research.e06_adversarial_data.util_gen_adversarial_dataset import sort_samples - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Dataset Mask # -# ========================================================================= # - -@torch.no_grad() -def _sample_stacked_batch(dataset: DisentDataset) -> torch.Tensor: - batch = next(iter(DataLoader(dataset, batch_size=1024, num_workers=0, shuffle=True))) - batch = torch.cat(batch['x_targ'], dim=0) - return batch - -@torch.no_grad() -def gen_approx_dataset_mask(dataset: DisentDataset, model_mask_mode: Optional[str]) -> Optional[torch.Tensor]: - if model_mask_mode in ('none', None): - mask = None - elif model_mask_mode == 'diff': - batch = _sample_stacked_batch(dataset) - mask = ~torch.all(batch[1:] == batch[0:1], dim=0) - elif model_mask_mode == 'std': - batch = _sample_stacked_batch(dataset) - mask = torch.std(batch, dim=0) - m, M = torch.min(mask), torch.max(mask) - mask = (mask - m) / (M - m) - else: - raise KeyError(f'invalid `model_mask_mode`: {repr(model_mask_mode)}') - # done - return mask - - -# 
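As a minimal sketch of the idea in the docstring above, `adv = obs + diff(obs)`, combined with the 'std' mask from `gen_approx_dataset_mask` (toy shapes and an untrained stand-in for the delta network, purely illustrative):

    import torch

    # toy stand-ins: a batch of observations and an untrained "delta" network
    obs = torch.rand(8, 3, 64, 64)
    delta_model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)

    # 'std' mask: per-pixel std over the batch, min-max normalised to [0, 1]
    mask = torch.std(obs, dim=0)
    mask = (mask - mask.min()) / (mask.max() - mask.min())

    # adversarial observation = input plus a masked learned perturbation
    with torch.no_grad():
        adv = obs + delta_model(obs) * mask[None, ...]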
========================================================================= # -# adversarial dataset generator # -# ========================================================================= # - - -class AeModel(AutoEncoder): - def forward(self, x): - return self.decode(self.encode(x)) - - -def make_delta_model(model_type: str, x_shape: Tuple[int, ...]): - C, H, W = x_shape - # get model - if model_type.startswith('ae_'): - return AeModel( - encoder=registry.MODELS[f'encoder_{model_type[len("ae_"):]}'](x_shape=x_shape, z_size=64, z_multiplier=1), - decoder=registry.MODELS[f'decoder_{model_type[len("ae_"):]}'](x_shape=x_shape, z_size=64, z_multiplier=1), - ) - elif model_type == 'fcn_small': - return torch.nn.Sequential( - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=C, out_channels=5, kernel_size=3), Swish(), - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=5, out_channels=7, kernel_size=3), Swish(), - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=7, out_channels=9, kernel_size=3), Swish(), - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=9, out_channels=7, kernel_size=3), Swish(), - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=7, out_channels=5, kernel_size=3), Swish(), - torch.nn.ReflectionPad2d(1), torch.nn.Conv2d(in_channels=5, out_channels=C, kernel_size=3), - ) - else: - raise KeyError(f'invalid model type: {repr(model_type)}') - - -class AdversarialAugmentModel(DisentModule): - - def __init__(self, model_type: str, x_shape=(3, 64, 64), mask=None, meta: dict = None): - super().__init__() - # make layers - self.delta_model = make_delta_model(model_type=model_type, x_shape=x_shape) - self.meta = meta if meta else {} - # mask - if mask is not None: - self.register_buffer('mask', mask[None, ...]) - assert self.mask.ndim == 4 # (1, C, H, W) - - def forward(self, x): - assert x.ndim == 4 - # compute - if hasattr(self, 'mask'): - return x + self.delta_model(x) * self.mask - else: - return x + self.delta_model(x) - - -# ========================================================================= # -# adversarial dataset generator # -# ========================================================================= # - - -class AdversarialModel(pl.LightningModule): - - def __init__( - self, - # optimizer options - optimizer_name: str = 'sgd', - optimizer_lr: float = 5e-2, - optimizer_kwargs: Optional[dict] = None, - # dataset config options - dataset_name: str = 'cars3d', - dataset_num_workers: int = min(os.cpu_count(), 16), - dataset_batch_size: int = 1024, # approx - data_root: str = 'data/dataset', - data_load_into_memory: bool = False, - # adversarial loss options - adversarial_mode: str = 'self', - adversarial_swapped: bool = False, - adversarial_masking: bool = False, - adversarial_top_k: Optional[int] = None, - pixel_loss_mode: str = 'mse', - # loss extras - loss_adversarial_weight: Optional[float] = 1.0, - loss_same_stats_weight: Optional[float] = 0.0, - loss_similarity_weight: Optional[float] = 0.0, - loss_out_of_bounds_weight: Optional[float] = 0.0, - # sampling config - sampler_name: str = 'close_far', - samples_sort_mode: str = 'none', - # model settings - model_type: str = 'ae_linear', - model_mask_mode: Optional[str] = 'none', - model_weight_init: str = 'xavier_normal', - # logging settings - logging_scale_imgs: bool = False, - # log_wb_stats_table: bool = True, - ): - super().__init__() - # modify hparams - if optimizer_kwargs is None: - optimizer_kwargs = {} # value used by save_hyperparameters - # save hparams - 
self.save_hyperparameters() - # variables - self.dataset: DisentDataset = None - self.sampler: BaseDisentSampler = None - self.model: DisentModule = None - - # ================================== # - # setup # - # ================================== # - - def prepare_data(self) -> None: - # create dataset - self.dataset = H.make_dataset( - self.hparams.dataset_name, - load_into_memory=self.hparams.data_load_into_memory, - load_memory_dtype=torch.float32, - data_root=self.hparams.data_root, - sampler=make_adversarial_sampler(self.hparams.sampler_name), - ) - # make the model - self.model = AdversarialAugmentModel( - model_type=self.hparams.model_type, - x_shape=(self.dataset.gt_data.img_channels, 64, 64), - mask=gen_approx_dataset_mask(dataset=self.dataset, model_mask_mode=self.hparams.model_mask_mode), - # if we save the model we can restore things! - meta=dict( - dataset_name=self.hparams.dataset_name, - dataset_factor_sizes=self.dataset.gt_data.factor_sizes, - dataset_factor_names=self.dataset.gt_data.factor_names, - sampler_name=self.hparams.sampler_name, - hparams=dict(self.hparams) - ), - ) - # initialize model - self.model = init_model_weights(self.model, mode=self.hparams.model_weight_init) - - def train_dataloader(self): - return DataLoader( - self.dataset, - batch_size=self.hparams.dataset_batch_size, - num_workers=self.hparams.dataset_num_workers, - shuffle=True, - ) - - def configure_optimizers(self): - return H.make_optimizer( - self.model, - name=self.hparams.optimizer_name, - lr=self.hparams.optimizer_lr, - **self.hparams.optimizer_kwargs, - ) - - # ================================== # - # train step # - # ================================== # - - def forward(self, x): - return self.model(x) - - def training_step(self, batch, batch_idx): - (a_x, p_x, n_x) = batch['x_targ'] - # sort inputs - a_x, p_x, n_x = sort_samples(a_x, p_x, n_x, sort_mode=self.hparams.samples_sort_mode, pixel_loss_mode=self.hparams.pixel_loss_mode) - # feed forward - a_y = self.model(a_x) - p_y = self.model(p_x) - n_y = self.model(n_x) - # compute loss - loss_adv = 0 - if (self.hparams.loss_adversarial_weight is not None) and (self.hparams.loss_adversarial_weight > 0): - loss_adv, loss_adv_stats = adversarial_loss( - ys=(a_y, p_y, n_y), - xs=(a_x, p_x, n_x), - adversarial_mode=self.hparams.adversarial_mode, - adversarial_swapped=self.hparams.adversarial_swapped, - adversarial_masking=self.hparams.adversarial_masking, - adversarial_top_k=self.hparams.adversarial_top_k, - pixel_loss_mode=self.hparams.pixel_loss_mode, - return_stats=True, - ) - loss_adv *= self.hparams.loss_adversarial_weight - self.log_dict(loss_adv_stats) - # additional loss components - # - keep stats the same - loss_stats = 0 - if (self.hparams.loss_same_stats_weight is not None) and (self.hparams.loss_same_stats_weight > 0): - loss_stats += (self.hparams.loss_same_stats_weight/3) * (( - F.mse_loss(a_y.mean(dim=[-3, -2, -1]), a_x.mean(dim=[-3, -2, -1]), reduction='mean') + - F.mse_loss(p_y.mean(dim=[-3, -2, -1]), p_x.mean(dim=[-3, -2, -1]), reduction='mean') + - F.mse_loss(n_y.mean(dim=[-3, -2, -1]), n_x.mean(dim=[-3, -2, -1]), reduction='mean') - ) + ( - F.mse_loss(a_y.std(dim=[-3, -2, -1]), a_x.std(dim=[-3, -2, -1]), reduction='mean') + - F.mse_loss(p_y.std(dim=[-3, -2, -1]), p_x.std(dim=[-3, -2, -1]), reduction='mean') + - F.mse_loss(n_y.std(dim=[-3, -2, -1]), n_x.std(dim=[-3, -2, -1]), reduction='mean') - )) - # - try keep similar to inputs - loss_sim = 0 - if (self.hparams.loss_similarity_weight is not None) and 
(self.hparams.loss_similarity_weight > 0): - loss_sim = (self.hparams.loss_similarity_weight / 3) * ( - F.mse_loss(a_y, a_x, reduction='mean') + - F.mse_loss(p_y, p_x, reduction='mean') + - F.mse_loss(n_y, n_x, reduction='mean') - ) - # - regularize if out of bounds - loss_out = 0 - if (self.hparams.loss_out_of_bounds_weight is not None) and (self.hparams.loss_out_of_bounds_weight > 0): - zeros = torch.zeros_like(a_y) - loss_out = (self.hparams.loss_out_of_bounds_weight / 6) * ( - torch.where(a_y < 0, -a_y, zeros).mean() + torch.where(a_y > 1, a_y-1, zeros).mean() + - torch.where(p_y < 0, -p_y, zeros).mean() + torch.where(p_y > 1, p_y-1, zeros).mean() + - torch.where(n_y < 0, -n_y, zeros).mean() + torch.where(n_y > 1, n_y-1, zeros).mean() - ) - # final loss -- sum of all the weighted components - loss = loss_adv + loss_stats + loss_sim + loss_out - # log everything - self.log_dict({ - 'loss': loss, - 'loss_stats': loss_stats, - 'loss_adv': loss_adv, - 'loss_out': loss_out, - 'loss_sim': loss_sim, - }, prog_bar=True) - # done! - return loss - - # ================================== # - # dataset # - # ================================== # - - @torch.no_grad() - def batch_to_adversarial_imgs(self, batch: torch.Tensor, m=0, M=1, mode='uint8') -> np.ndarray: - batch = batch.to(device=self.device, dtype=torch.float32) - batch = self.model(batch) - batch = (batch - m) / (M - m) - if mode == 'uint8': return H.to_imgs(batch).numpy() - elif mode == 'float32': return torch.moveaxis(batch, -3, -1).to(torch.float32).cpu().numpy() - elif mode == 'float16': return torch.moveaxis(batch, -3, -1).to(torch.float16).cpu().numpy() - else: raise KeyError(f'invalid output mode: {repr(mode)}') - - def make_train_periodic_callbacks(self, cfg) -> Sequence[BaseCallbackPeriodic]: - - # dataset transform helper - @TempNumpySeed(42) - @torch.no_grad() - def make_scale_uint8_transform(): - # get scaling values - if self.hparams.logging_scale_imgs: - samples = self.dataset.dataset_sample_batch(num_samples=128, mode='raw').to(torch.float32) - samples = self.model(samples.to(self.device)).cpu() - m, M = float(torch.min(samples)), float(torch.max(samples)) - else: - m, M = 0, 1 - return lambda x: self.batch_to_adversarial_imgs(x[None, ...], m=m, M=M)[0] - - # base dataset callback - class _BaseDatasetCallback(BaseCallbackPeriodic): - @TempNumpySeed(777) - @torch.no_grad() - def do_step(this, trainer: pl.Trainer, system: AdversarialModel): - if not wb_has_logger(trainer.logger): - log.warning(f'no wandb logger found, skipping visualisation: {system.__class__.__name__}') - return - if system.dataset is None: - log.warning(f'dataset not initialized, skipping visualisation: {system.__class__.__name__}') - return - log.info(f'visualising: {this.__class__.__name__}') - try: - this._do_step(trainer, system) - except: - log.error('Failed to do visualise callback step!', exc_info=True) - - # override this - def _do_step(this, trainer: pl.Trainer, system: AdversarialModel): - raise NotImplementedError - - # show image callback - class ImShowCallback(_BaseDatasetCallback): - def _do_step(this, trainer: pl.Trainer, system: AdversarialModel): - # make dataset with required transform - # -- this is inefficient for multiple subclasses of this class, we need to recompute the transform each time - dataset = system.dataset.shallow_copy(transform=make_scale_uint8_transform()) - # get images & traversal - image = make_image_grid(dataset.dataset_sample_batch(num_samples=16, mode='input')) - wandb_image, wandb_animation = H.visualize_dataset_traversal(dataset, data_mode='input', output_wandb=True)
- # log images to WANDB - wb_log_metrics(trainer.logger, { - 'random_images': wandb.Image(image), - 'traversal_image': wandb_image, - 'traversal_animation': wandb_animation, - }) - - # factor distances callback - class DistsPlotCallback(_BaseDatasetCallback): - def _do_step(this, trainer: pl.Trainer, system: AdversarialModel): - from disent.util.lightning.callbacks._callbacks_vae import compute_factor_distances, plt_factor_distances - - # make distances function - def dists_fn(xs_a, xs_b): - dists = H.pairwise_loss(xs_a, xs_b, mode=system.hparams.pixel_loss_mode, mean_dtype=torch.float32, mask=None) - return [dists] - - def transform_batch(batch): - return system.model(batch.to(device=system.device)) - - # compute various distances matrices for each factor - dists_names, f_grid = compute_factor_distances( - dataset=system.dataset, - dists_fn=dists_fn, - dists_names=['dists'], - traversal_repeats=100, - batch_size=system.hparams.dataset_batch_size, - include_gt_factor_dists=True, - transform_batch=transform_batch, - seed=777, - data_mode='input', - ) - # plot these results - fig, axs = plt_factor_distances( - gt_data=system.dataset.gt_data, - f_grid=f_grid, - dists_names=dists_names, - title=f'{system.hparams.model_type.capitalize()}: {system.hparams.dataset_name.capitalize()} Distances', - plt_block_size=1.25, - plt_transpose=True, - plt_cmap='Blues', - ) - # recolour dists axis - for ax in axs[-1, :]: - ax.images[0].set_cmap('Reds') - # generate image & close matplotlib instace - from matplotlib import pyplot as plt - img = wandb.Image(fig) - plt.close() - # log the plot to wandb - if True: - wb_log_metrics(trainer.logger, { - 'factor_distances': img - }) - - # show stats callback - class StatsShowCallback(_BaseDatasetCallback): - def _do_step(this, trainer: pl.Trainer, system: AdversarialModel): - # make dataset with required transform - # -- this is inefficient for multiple subclasses of this class, we need to recompute the transform each time - dataset = system.dataset.shallow_copy(transform=make_scale_uint8_transform()) - # get batches - batch, factors = dataset.dataset_sample_batch_with_factors(num_samples=512, mode='input') - batch = batch.to(torch.float32) - a_idx = torch.randint(0, len(batch), size=[4*len(batch)]) - b_idx = torch.randint(0, len(batch), size=[4*len(batch)]) - mask = (a_idx != b_idx) - # TODO: check that this is deterministic - # compute distances - deltas = to_numpy(H.pairwise_overlap(batch[a_idx[mask]], batch[b_idx[mask]], mode='mse')) - fdists = to_numpy(torch.abs(factors[a_idx[mask]] - factors[b_idx[mask]]).sum(dim=-1)) - sdists = to_numpy((torch.abs(factors[a_idx[mask]] - factors[b_idx[mask]]) / to_numpy(dataset.gt_data.factor_sizes)[None, :]).sum(dim=-1)) - # log to wandb - from matplotlib import pyplot as plt - plt.scatter(fdists, deltas); img_fdists = wandb.Image(plt); plt.close() - plt.scatter(sdists, deltas); img_sdists = wandb.Image(plt); plt.close() - wb_log_metrics(trainer.logger, { - 'fdists_vs_overlap': img_fdists, - 'sdists_vs_overlap': img_sdists, - }) - - # done! 
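# NOTE (worked example, not from the original file): in `StatsShowCallback` above,
# the scatter plots compare pixel-space distances (`deltas`) against ground-truth
# factor distances. For factors a=[2, 5], b=[4, 5] with factor_sizes=[10, 8]:
#   fdists = |2-4| + |5-5|       = 2     (raw L1 distance in factor space)
#   sdists = |2-4|/10 + |5-5|/8  = 0.2   (normalised by the factor sizes)
# i.e. the plots show how well ground-truth factor distances predict perceptual
# distances under the adversarial transform.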
- return [ - ImShowCallback(every_n_steps=cfg.settings.exp.show_every_n_steps, begin_first_step=True), - DistsPlotCallback(every_n_steps=cfg.settings.exp.show_every_n_steps, begin_first_step=True), - StatsShowCallback(every_n_steps=cfg.settings.exp.show_every_n_steps, begin_first_step=True), - ] - - -# ========================================================================= # -# Run Hydra # -# ========================================================================= # - - -ROOT_DIR = os.path.abspath(__file__ + '/../../..') - - -def run_gen_adversarial_dataset(cfg): - time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') - log.info(f'Starting run at time: {time_string}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # cleanup from old runs: - try: - wandb.finish() - except: - pass - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - cfg = make_non_strict(cfg) - # - - - - - - - - - - - - - - - # - # check CUDA setting - gpus = hydra_get_gpus(cfg) - # create logger - logger = hydra_make_logger(cfg) - # create callbacks - callbacks: List[pl.Callback] = [c for c in hydra_get_callbacks(cfg) if isinstance(c, LoggerProgressCallback)] - # - - - - - - - - - - - - - - - # - # check save dirs - assert not os.path.isabs(cfg.settings.exp.rel_save_dir), f'rel_save_dir must be relative: {repr(cfg.settings.exp.rel_save_dir)}' - save_dir = os.path.join(ROOT_DIR, cfg.settings.exp.rel_save_dir) - assert os.path.isabs(save_dir), f'save_dir must be absolute: {repr(save_dir)}' - # - - - - - - - - - - - - - - - # - # get the logger and initialize - if logger is not None: - logger.log_hyperparams(cfg) - # print the final config! - log.info('Final Config' + make_box_str(OmegaConf.to_yaml(cfg))) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # | | | | | | | | | | | | | | | # - seed(cfg.settings.job.seed) - # | | | | | | | | | | | | | | | # - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # make framework - framework = AdversarialModel(**cfg.adv_system) - callbacks.extend(framework.make_train_periodic_callbacks(cfg)) - # train - trainer = pl.Trainer( - logger=logger, - callbacks=callbacks, - # cfg.dsettings.trainer - gpus=gpus, - # cfg.trainer - max_epochs=cfg.trainer.max_epochs, - max_steps=cfg.trainer.max_steps, - log_every_n_steps=cfg.trainer.log_every_n_steps, - flush_logs_every_n_steps=cfg.trainer.flush_logs_every_n_steps, - progress_bar_refresh_rate=cfg.trainer.progress_bar_refresh_rate, - prepare_data_per_node=cfg.trainer.prepare_data_per_node, - # we do this here so we don't run the final metrics - terminate_on_nan=True, - checkpoint_callback=False, - ) - trainer.fit(framework) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # get save paths - save_prefix = f'{cfg.settings.exp.save_prefix}_' if cfg.settings.exp.save_prefix else '' - save_path_model = os.path.join(save_dir, f'{save_prefix}{time_string}_{cfg.settings.job.name}', f'model.pt') - save_path_data = os.path.join(save_dir, f'{save_prefix}{time_string}_{cfg.settings.job.name}', f'data.h5') - # create directories - if cfg.settings.exp.save_model: ensure_parent_dir_exists(save_path_model) - if cfg.settings.exp.save_data: ensure_parent_dir_exists(save_path_data) - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # save adversarial model - if cfg.settings.exp.save_model: - log.info(f'saving model to path: {repr(save_path_model)}') - torch.save(framework.model, save_path_model) - log.info(f'saved model size: {bytes_to_human(os.path.getsize(save_path_model))}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - # save adversarial dataset - if cfg.settings.exp.save_data: - log.info(f'saving data to path: 
{repr(save_path_data)}') - # transfer to GPU - if torch.cuda.is_available(): - framework = framework.cuda() - # create new h5py file -- TODO: use this in other places! - with H5Builder(path=save_path_data, mode='atomic_w') as builder: - # this dataset is self-contained and can be loaded by SelfContainedHdf5GroundTruthData - builder.add_dataset_from_gt_data( - data=framework.dataset, # produces tensors - mutator=wrapped_partial(framework.batch_to_adversarial_imgs, mode=cfg.settings.exp.save_dtype), # consumes tensors -> np.ndarrays - img_shape=(64, 64, None), - compression_lvl=4, - dtype=cfg.settings.exp.save_dtype, - batch_size=32, - ) - log.info(f'saved data size: {bytes_to_human(os.path.getsize(save_path_data))}') - # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # - - -# ========================================================================= # -# Entry Point # -# ========================================================================= # - - -if __name__ == '__main__': - - # BENCHMARK (batch_size=256, optimizer=sgd, lr=1e-2, dataset_num_workers=0): - # - batch_optimizer=False, gpu=True, fp16=True : [3168MiB/5932MiB, 3.32/11.7G, 5.52it/s] - # - batch_optimizer=False, gpu=True, fp16=False : [5248MiB/5932MiB, 3.72/11.7G, 4.84it/s] - # - batch_optimizer=False, gpu=False, fp16=True : [same as fp16=False] - # - batch_optimizer=False, gpu=False, fp16=False : [0003MiB/5932MiB, 4.60/11.7G, 1.05it/s] - # --------- - # - batch_optimizer=True, gpu=True, fp16=True : [1284MiB/5932MiB, 3.45/11.7G, 4.31it/s] - # - batch_optimizer=True, gpu=True, fp16=False : [1284MiB/5932MiB, 3.72/11.7G, 4.31it/s] - # - batch_optimizer=True, gpu=False, fp16=True : [same as fp16=False] - # - batch_optimizer=True, gpu=False, fp16=False : [0003MiB/5932MiB, 1.80/11.7G, 4.18it/s] - - # BENCHMARK (batch_size=1024, optimizer=sgd, lr=1e-2, dataset_num_workers=12): - # - batch_optimizer=True, gpu=True, fp16=True : [2510MiB/5932MiB, 4.10/11.7G, 4.75it/s, 20% gpu util] (to(device).to(dtype)) - # - batch_optimizer=True, gpu=True, fp16=True : [2492MiB/5932MiB, 4.10/11.7G, 4.12it/s, 19% gpu util] (to(device, dtype)) - - @hydra.main(config_path=os.path.join(ROOT_DIR, 'experiment/config'), config_name="config_adversarial_dataset_approx") - def main(cfg): - try: - run_gen_adversarial_dataset(cfg) - except Exception as e: - # truncate error - err_msg = str(e) - err_msg = err_msg[:244] + ' ' if len(err_msg) > 244 else err_msg - # log something at least - log.error(f'exiting: experiment error | {err_msg}', exc_info=True) - - # EXP ARGS: - # $ ... 
-m dataset=smallnorb,shapes3d - try: - main() - except KeyboardInterrupt as e: - log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) - except Exception as e: - log_error_and_exit(err_type='hydra error', err_msg=str(e)) diff --git a/research/e06_adversarial_data/util_eval_adversarial.py b/research/e06_adversarial_data/util_eval_adversarial.py deleted file mode 100644 index 8b4c9b5c..00000000 --- a/research/e06_adversarial_data/util_eval_adversarial.py +++ /dev/null @@ -1,348 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from typing import Tuple - -import numpy as np -from numba import njit -from scipy.stats import gmean - - -# ========================================================================= # -# Aggregate # -# ========================================================================= # - - -_NP_AGGREGATE_FNS = { - 'sum': np.sum, - 'mean': np.mean, - 'gmean': gmean, # no negatives - 'max': lambda a, axis, dtype: np.amax(a, axis=axis), # propagate NaNs - 'min': lambda a, axis, dtype: np.amin(a, axis=axis), # propagate NaNs - 'std': np.std, -} - - -def np_aggregate(array, mode: str, axis=0, dtype=None): - try: - fn = _NP_AGGREGATE_FNS[mode] - except KeyError: - raise KeyError(f'invalid aggregate mode: {repr(mode)}, must be one of: {sorted(_NP_AGGREGATE_FNS.keys())}') - result = fn(array, axis=axis, dtype=dtype) - if dtype is not None: - result = result.astype(dtype) - return result - - -# ========================================================================= # -# Factor Evaluation - SLOW # -# ========================================================================= # - - -def eval_factor_fitness_numpy( - individual: np.ndarray, - f_idx: int, - f_dist_matrices: np.ndarray, - factor_sizes: Tuple[int, ...], - fitness_mode: str, - exclude_diag: bool, - increment_single: bool = True, -) -> float: - assert increment_single, f'`increment_single=False` is not supported for numpy fitness evaluation' - # generate missing mask axis - mask = individual.reshape(factor_sizes) - mask = np.moveaxis(mask, f_idx, -1) - f_mask = mask[..., :, None] & mask[..., None, :] - # the diagonal can change statistics - if exclude_diag: - diag = np.arange(f_mask.shape[-1]) - f_mask[..., diag, diag] = False - # mask the distance array | we negate the 
mask so that TRUE means the item is disabled - f_dists = np.ma.masked_where(~f_mask, f_dist_matrices) - - # get distances - if fitness_mode == 'range': agg_vals = np.ma.max(f_dists, axis=-1) - np.ma.min(f_dists, axis=-1) - elif fitness_mode == 'max': agg_vals = np.ma.max(f_dists, axis=-1) - elif fitness_mode == 'std': agg_vals = np.ma.std(f_dists, axis=-1) - else: raise KeyError(f'invalid fitness_mode: {repr(fitness_mode)}') - - # mean -- there is still a slight difference between this version - # and the numba version, but this helps improve things... - # It might just be a precision error? - fitness_sparse = np.ma.masked_where(~mask, agg_vals).mean() - - # combined scores - return fitness_sparse - - -# ========================================================================= # -# Factor Evaluation - FAST # -# ========================================================================= # - - -@njit -def eval_factor_fitness_numba__std_nodiag( - mask: np.ndarray, - f_dists: np.ndarray, - increment_single: bool = True -): - """ - This is about 10x faster than the built in numpy version - """ - assert f_dists.shape == (*mask.shape, mask.shape[-1]) - # totals - total = 0.0 - count = 0 - # iterate over values -- np.ndindex is usually quite fast - for I in np.ndindex(mask.shape[:-1]): - # mask is broadcast to the distance matrix - m_row = mask[I] - d_mat = f_dists[I] - # handle each distance matrix -- enumerate is usually faster than range - for i, m in enumerate(m_row): - if not m: - continue - # get vars - dists = d_mat[i] - # init vars - n = 0 - s = 0.0 - s2 = 0.0 - # handle each row -- enumerate is usually faster than range - for j, d in enumerate(dists): - if i == j: - continue - if not m_row[j]: - continue - n += 1 - s += d - s2 += d*d - # ^^^ END j - # update total - if n > 1: - mean2 = (s * s) / (n * n) - m2 = (s2 / n) - # is this just needed because of precision errors? - if m2 > mean2: - total += np.sqrt(m2 - mean2) - count += 1 - elif increment_single and (n == 1): - total += 0. 
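# NOTE (worked example, not from the original file): the `n > 1` branch above
# computes a population std from running sums via  var = E[d^2] - (E[d])^2 :
#   dists = [1.0, 2.0, 3.0]  ->  n=3, s=6.0, s2=14.0
#   var = 14/3 - (6/3)**2 = 2/3  ->  std = sqrt(2/3) ~= 0.8165 == np.std(dists)
# the `m2 > mean2` guard only skips the tiny negative values that floating
# point rounding can produce when the true variance is ~0.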
- count += 1 - # ^^^ END i - if count == 0: - return -1 - else: - return total / count - - -@njit -def eval_factor_fitness_numba__range_nodiag( - mask: np.ndarray, - f_dists: np.ndarray, - increment_single: bool = True, -): - """ - This is about 10x faster than the built in numpy version - """ - assert f_dists.shape == (*mask.shape, mask.shape[-1]) - # totals - total = 0.0 - count = 0 - # iterate over values -- np.ndindex is usually quite fast - for I in np.ndindex(mask.shape[:-1]): - # mask is broadcast to the distance matrix - m_row = mask[I] - d_mat = f_dists[I] - # handle each distance matrix -- enumerate is usually faster than range - for i, m in enumerate(m_row): - if not m: - continue - # get vars - dists = d_mat[i] - # init vars - num_checked = False - m = 0.0 - M = 0.0 - # handle each row -- enumerate is usually faster than range - for j, d in enumerate(dists): - if i == j: - continue - if not m_row[j]: - continue - # update range - if num_checked > 0: - if d < m: - m = d - if d > M: - M = d - else: - m = d - M = d - # update num checked - num_checked += 1 - # ^^^ END j - # update total - if (num_checked > 1) or (increment_single and num_checked == 1): - total += (M - m) - count += 1 - # ^^^ END i - if count == 0: - return -1 - else: - return total / count - - -def eval_factor_fitness_numba( - individual: np.ndarray, - f_idx: int, - f_dist_matrices: np.ndarray, - factor_sizes: Tuple[int, ...], - fitness_mode: str, - exclude_diag: bool, - increment_single: bool = True, -): - """ - We only keep this function as a compatibility layer between: - - eval_factor_fitness_numpy - - eval_factor_fitness_numba__range_nodiag - """ - assert exclude_diag, 'fast version of eval only supports `exclude_diag=True`' - # usually a view - mask = np.moveaxis(individual.reshape(factor_sizes), f_idx, -1) - # call - if fitness_mode == 'range': - return eval_factor_fitness_numba__range_nodiag(mask=mask, f_dists=f_dist_matrices, increment_single=increment_single) - elif fitness_mode == 'std': - return eval_factor_fitness_numba__std_nodiag(mask=mask, f_dists=f_dist_matrices, increment_single=increment_single) - else: - raise KeyError(f'fast version of eval only supports `fitness_mode in ("range", "std")`, got: {repr(fitness_mode)}') - - -# ========================================================================= # -# Individual Evaluation # -# ========================================================================= # - - -_EVAL_BACKENDS = { - 'numpy': eval_factor_fitness_numpy, - 'numba': eval_factor_fitness_numba, -} - - -def eval_individual( - individual: np.ndarray, - gt_dist_matrices: np.ndarray, - factor_sizes: Tuple[int, ...], - fitness_overlap_mode: str, - fitness_overlap_aggregate: str, - exclude_diag: bool, - increment_single: bool = True, - backend: str = 'numba', -) -> Tuple[float, float]: - # get function - if backend not in _EVAL_BACKENDS: - raise KeyError(f'invalid backend: {repr(backend)}, must be one of: {sorted(_EVAL_BACKENDS.keys())}') - eval_fn = _EVAL_BACKENDS[backend] - # evaluate all factors - factor_scores = np.array([ - [eval_fn(individual, f_idx, f_dist_matrices, factor_sizes=factor_sizes, fitness_mode=fitness_overlap_mode, exclude_diag=exclude_diag, increment_single=increment_single)] - for f_idx, f_dist_matrices in enumerate(gt_dist_matrices) - ]) - # aggregate - factor_score = np_aggregate(factor_scores[:, 0], mode=fitness_overlap_aggregate, dtype='float64') - kept_ratio = individual.mean() - # check values just in case something goes wrong! 
- factor_score = np.nan_to_num(factor_score, nan=float('-inf')) - kept_ratio = np.nan_to_num(kept_ratio, nan=float('-inf')) - # return values! - return float(factor_score), float(kept_ratio) - - -# ========================================================================= # -# Equality Checks # -# ========================================================================= # - - -def _check_equal( - dataset_name: str = 'dsprites', - fitness_mode: str = 'std', # range, std - n: int = 5, -): - from research.e01_visual_overlap.util_compute_traversal_dists import cached_compute_all_factor_dist_matrices - from timeit import timeit - import research.util as H - - # load data - gt_data = H.make_data(dataset_name) - print(f'{dataset_name} {gt_data.factor_sizes} : {fitness_mode}') - - # get distances & individual - all_dist_matrices = cached_compute_all_factor_dist_matrices(dataset_name) # SHAPE FOR: s=factor_sizes, i=f_idx | (*s[:i], *s[i+1:], s[i], s[i]) - mask = np.random.random(len(gt_data)) < 0.5 # SHAPE: (-1,) - - def eval_factor(backend: str, f_idx: int, increment_single=True): - return _EVAL_BACKENDS[backend]( - individual=mask, - f_idx=f_idx, - f_dist_matrices=all_dist_matrices[f_idx], - factor_sizes=gt_data.factor_sizes, - fitness_mode=fitness_mode, - exclude_diag=True, - increment_single=increment_single, - ) - - def eval_all(backend: str, increment_single=True): - return np.around([eval_factor(backend, i, increment_single=increment_single) for i in range(gt_data.num_factors)], decimals=15) - - new_vals = eval_all('numba', increment_single=False) - new_time = timeit(lambda: eval_all('numba', increment_single=False), number=n) / n - print(f'- NEW {new_time:.5f}s {new_vals} (increment_single=False)') - - new_vals = eval_all('numba') - new_time = timeit(lambda: eval_all('numba'), number=n) / n - print(f'- NEW {new_time:.5f}s {new_vals}') - - old_vals = eval_all('numpy') - old_time = timeit(lambda: eval_all('numpy'), number=n) / n - print(f'- OLD {old_time:.5f}s {old_vals}') - print(f'* speedup: {np.around(old_time/new_time, decimals=2)}x') - - if not np.allclose(new_vals, old_vals): - print('[WARNING]: values are not close!') - - -if __name__ == '__main__': - - for dataset_name in ['smallnorb', 'shapes3d', 'dsprites']: - print('='*100) - _check_equal(dataset_name, fitness_mode='std') - print() - _check_equal(dataset_name, fitness_mode='range') - print('='*100) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e06_adversarial_data/util_eval_adversarial_dist_pairs.py b/research/e06_adversarial_data/util_eval_adversarial_dist_pairs.py deleted file mode 100644 index 0146b254..00000000 --- a/research/e06_adversarial_data/util_eval_adversarial_dist_pairs.py +++ /dev/null @@ -1,291 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all 
copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -from typing import Tuple - -import numpy as np -from numba import njit - - -# ========================================================================= # -# Factor Evaluation - SLOW # -# ========================================================================= # -from disent.util.profiling import Timer - - -def eval_dist_pairs_numpy( - mask: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - fitness_mode: str, - increment_single: bool = True -) -> float: - assert increment_single, f'`increment_single=False` is not supported for numpy fitness evaluation' - # mask the distance array | we negate the mask so that TRUE means the item is disabled - dists = np.ma.masked_where(~mask[pair_obs_idxs], pair_obs_dists) - # get distances - if fitness_mode == 'range': agg_vals = np.ma.max(dists, axis=-1) - np.ma.min(dists, axis=-1) - elif fitness_mode == 'std': agg_vals = np.ma.std(dists, axis=-1) - else: raise KeyError(f'invalid fitness_mode: {repr(fitness_mode)}') - # mean -- there is still a slight difference between this version - # and the numba version, but this helps improve things... - # It might just be a precision error? - fitness_sparse = np.ma.masked_where(~mask, agg_vals).mean() - # combined scores - return fitness_sparse - - -# ========================================================================= # -# Factor Evaluation - FAST # -# ========================================================================= # - - -@njit -def eval_dist_pairs_numba__std( - mask: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - increment_single: bool = True -): - """ - This is about 10x faster than the built in numpy version - -- something is wrong compared to the numpy version, maybe the - numpy version is wrong because of the mean taken after masking? - """ - assert len(mask) == len(pair_obs_dists) - assert len(mask) == len(pair_obs_idxs) - assert pair_obs_dists.shape == pair_obs_idxs.shape - # totals - total = 0.0 - count = 0 - # iterate over values -- np.ndindex is usually quite fast - for i, m in enumerate(mask): - # skip if invalid - if not m: - continue - # get pair info - dists = pair_obs_dists[i] - idxs = pair_obs_idxs[i] - # init vars - n = 0 - s = 0.0 - s2 = 0.0 - # handle each distance matrix -- enumerate is usually faster than range - for j, d in zip(idxs, dists): - # skip if invalid - if not mask[j]: - continue - # compute std - n += 1 - s += d - s2 += d*d - # update total -- TODO: numpy includes this, but we might not want to? - if n > 1: - mean2 = (s * s) / (n * n) - m2 = (s2 / n) - # is this just needed because of precision errors? - if m2 > mean2: - total += np.sqrt(m2 - mean2) - count += 1 - elif increment_single and (n == 1): - total += 0. 
- count += 1 - # ^^^ END i - if count == 0: - return -1 - else: - return total / count - - -@njit -def eval_dist_pairs_numba__range( - mask: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - increment_single: bool = True -): - """ - This is about 10x faster than the built in numpy version - """ - assert len(mask) == len(pair_obs_dists) - assert len(mask) == len(pair_obs_idxs) - assert pair_obs_dists.shape == pair_obs_idxs.shape - # totals - total = 0.0 - count = 0 - # iterate over values -- np.ndindex is usually quite fast - for i, m in enumerate(mask): - # skip if invalid - if not m: - continue - # get pair info - dists = pair_obs_dists[i] - idxs = pair_obs_idxs[i] - # init vars - num_checked = 0 - m = 0.0 - M = 0.0 - # handle each distance matrix -- enumerate is usually faster than range - for j, d in zip(idxs, dists): - # skip if invalid - if not mask[j]: - continue - # update range - if num_checked > 0: - if d < m: m = d - if d > M: M = d - else: - m = d - M = d - # update num checked - num_checked += 1 - # update total - if (num_checked > 1) or (increment_single and num_checked == 1): - total += (M - m) - count += 1 - # ^^^ END i - if count == 0: - return -1 - else: - return total / count - - -def eval_dist_pairs_numba( - mask: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - fitness_mode: str, - increment_single: bool = True -): - """ - We only keep this function as a compatibility layer between: - - eval_numpy - - eval_numba__range_nodiag - """ - # call - if fitness_mode == 'range': - return eval_dist_pairs_numba__range(mask=mask, pair_obs_dists=pair_obs_dists, pair_obs_idxs=pair_obs_idxs, increment_single=increment_single) - elif fitness_mode == 'std': - return eval_dist_pairs_numba__std(mask=mask, pair_obs_dists=pair_obs_dists, pair_obs_idxs=pair_obs_idxs, increment_single=increment_single) - else: - raise KeyError(f'fast version of eval only supports `fitness_mode in ("range", "std")`, got: {repr(fitness_mode)}') - - -# ========================================================================= # -# Individual Evaluation # -# ========================================================================= # - - -_EVAL_BACKENDS = { - 'numpy': eval_dist_pairs_numpy, - 'numba': eval_dist_pairs_numba, -} - - -def eval_masked_dist_pairs( - mask: np.ndarray, - pair_obs_dists: np.ndarray, - pair_obs_idxs: np.ndarray, - fitness_mode: str, - increment_single: bool = True, - backend: str = 'numba', -) -> Tuple[float, float]: - # get function - if backend not in _EVAL_BACKENDS: - raise KeyError(f'invalid backend: {repr(backend)}, must be one of: {sorted(_EVAL_BACKENDS.keys())}') - eval_fn = _EVAL_BACKENDS[backend] - # evaluate - factor_score = eval_fn( - mask=mask, - pair_obs_dists=pair_obs_dists, - pair_obs_idxs=pair_obs_idxs, - fitness_mode=fitness_mode, - increment_single=increment_single, - ) - # aggregate - kept_ratio = mask.mean() - # check values just in case something goes wrong! - factor_score = np.nan_to_num(factor_score, nan=float('-inf')) - kept_ratio = np.nan_to_num(kept_ratio, nan=float('-inf')) - # return values! 
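# NOTE (hypothetical usage, not from the original file): callers can treat the
# returned pair as a two-part objective when searching over masks, e.g.:
#   score, ratio = eval_masked_dist_pairs(mask, dists, idxs, fitness_mode='range')
#   objective = score - 10.0 * abs(ratio - 0.5)   # assumed weighting, keep ~half
# where `dists`/`idxs` come from cached_compute_dataset_pair_dists(...).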
- return float(factor_score), float(kept_ratio) - - -# ========================================================================= # -# Equality Checks # -# ========================================================================= # - - -def _check_equal( - dataset_name: str = 'dsprites', - pair_mode: str = 'nearby_scaled', - pairs_per_obs: int = 8, - fitness_mode: str = 'std', # range, std - n: int = 5, -): - from research.e01_visual_overlap.util_compute_traversal_dist_pairs import cached_compute_dataset_pair_dists - from timeit import timeit - - # get distances & individual # (len(gt_data), pairs_per_obs) & (len(gt_data),) - obs_pair_idxs, obs_pair_dists = cached_compute_dataset_pair_dists(dataset_name=dataset_name, pair_mode=pair_mode, pairs_per_obs=pairs_per_obs, scaled=True) - mask = np.random.random(len(obs_pair_idxs)) < 0.5 - - def eval_all(backend: str, increment_single=True): - return _EVAL_BACKENDS[backend]( - mask=mask, - pair_obs_dists=obs_pair_dists, - pair_obs_idxs=obs_pair_idxs, - fitness_mode=fitness_mode, - increment_single=increment_single, - ) - - new_vals = eval_all('numba', increment_single=False) - new_time = timeit(lambda: eval_all('numba', increment_single=False), number=n) / n - print(f'- NEW {new_time:.5f}s {new_vals} (increment_single=False)') - - new_vals = eval_all('numba') - new_time = timeit(lambda: eval_all('numba'), number=n) / n - print(f'- NEW {new_time:.5f}s {new_vals}') - - old_vals = eval_all('numpy') - old_time = timeit(lambda: eval_all('numpy'), number=n) / n - print(f'- OLD {old_time:.5f}s {old_vals}') - print(f'* speedup: {np.around(old_time/new_time, decimals=2)}x') - - if not np.allclose(new_vals, old_vals): - print('[WARNING]: values are not close!') - - -if __name__ == '__main__': - - for dataset_name in ['smallnorb', 'shapes3d', 'dsprites']: - print('='*100) - _check_equal(dataset_name, fitness_mode='std') - print() - _check_equal(dataset_name, fitness_mode='range') - print('='*100) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e06_adversarial_data/util_gen_adversarial_dataset.py b/research/e06_adversarial_data/util_gen_adversarial_dataset.py deleted file mode 100644 index 4db566f3..00000000 --- a/research/e06_adversarial_data/util_gen_adversarial_dataset.py +++ /dev/null @@ -1,446 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -""" -General helper utilities for generating -adversarial datasets using triplet sampling. -""" - -import logging -from functools import lru_cache -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import Union - -import numpy as np -import torch - -import research.util as H -from disent.dataset.data import GroundTruthData -from disent.dataset.sampling import BaseDisentSampler -from disent.dataset.sampling import GroundTruthPairSampler -from disent.dataset.sampling import GroundTruthTripleSampler -from disent.dataset.sampling import RandomSampler -from disent.util.strings import colors as c - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# Samplers # -# ========================================================================= # - - -class AdversarialSampler_SwappedRandom(BaseDisentSampler): - - def uninit_copy(self) -> 'AdversarialSampler_SwappedRandom': - return AdversarialSampler_SwappedRandom(swap_metric=self._swap_metric) - - def __init__(self, swap_metric='manhattan'): - super().__init__(3) - assert swap_metric in {'k', 'manhattan', 'manhattan_norm', 'euclidean', 'euclidean_norm'} - self._swap_metric = swap_metric - self._sampler = GroundTruthTripleSampler(swap_metric=swap_metric) - self._gt_data: GroundTruthData = None - - def _init(self, gt_data: GroundTruthData): - self._sampler.init(gt_data) - self._gt_data = gt_data - - def _sample_idx(self, idx: int) -> Tuple[int, ...]: - anchor, pos, neg = self._gt_data.idx_to_pos([ - idx, - *np.random.randint(0, len(self._gt_data), size=2) - ]) - # swap values - pos, neg = self._sampler._swap_factors(anchor_factors=anchor, positive_factors=pos, negative_factors=neg) - # return triple - return tuple(self._gt_data.pos_to_idx([anchor, pos, neg])) - - -class AdversarialSampler_CloseFar(BaseDisentSampler): - - def uninit_copy(self) -> 'AdversarialSampler_CloseFar': - return AdversarialSampler_CloseFar( - p_k_range=self._p_k_range, - p_radius_range=self._p_radius_range, - n_k_range=self._n_k_range, - n_radius_range=self._n_radius_range, - ) - - def __init__( - self, - p_k_range=(1, 1), - p_radius_range=(1, 1), - n_k_range=(1, -1), - n_radius_range=(1, -1), - ): - super().__init__(3) - self._p_k_range = p_k_range - self._p_radius_range = p_radius_range - self._n_k_range = n_k_range - self._n_radius_range = n_radius_range - self.sampler_close = GroundTruthPairSampler(p_k_range=p_k_range, p_radius_range=p_radius_range) - self.sampler_far = GroundTruthPairSampler(p_k_range=n_k_range, p_radius_range=n_radius_range) - - def _init(self, gt_data: GroundTruthData): - self.sampler_close.init(gt_data) - self.sampler_far.init(gt_data) - - def _sample_idx(self, idx: int) -> Tuple[int, ...]: - # sample indices - anchor, pos = self.sampler_close(idx) - _anchor, neg = self.sampler_far(idx) - assert anchor == _anchor - # return triple - return anchor, pos, neg - - -class AdversarialSampler_SameK(BaseDisentSampler): - - def uninit_copy(self) -> 'AdversarialSampler_SameK': - return AdversarialSampler_SameK( - k=self._k, - sample_p_close=self._sample_p_close, - ) - - def __init__(self, k: 
Union[Literal['random'], int] = 'random', sample_p_close: bool = False): - super().__init__(3) - self._gt_data: GroundTruthData = None - self._sample_p_close = sample_p_close - self._k = k - assert (isinstance(k, int) and k > 0) or (k == 'random') - - def _init(self, gt_data: GroundTruthData): - self._gt_data = gt_data - - def _sample_idx(self, idx: int) -> Tuple[int, ...]: - a_factors = self._gt_data.idx_to_pos(idx) - # SAMPLE FACTOR INDICES - k = self._k - if k == 'random': - k = np.random.randint(1, self._gt_data.num_factors+1) # end exclusive, ie. [1, num_factors+1) - # get shared mask - shared_indices = np.random.choice(self._gt_data.num_factors, size=self._gt_data.num_factors-k, replace=False) - shared_mask = np.zeros(a_factors.shape, dtype='bool') - shared_mask[shared_indices] = True - # generate values - p_factors = self._sample_shared(a_factors, shared_mask, sample_close=self._sample_p_close) - n_factors = self._sample_shared(a_factors, shared_mask, sample_close=False) - # swap values if wrong - # TODO: this might give errors! - # - one factor might be less than another - if np.sum(np.abs(a_factors - p_factors)) > np.sum(np.abs(a_factors - n_factors)): - p_factors, n_factors = n_factors, p_factors - # check values - assert np.sum(a_factors != p_factors) == k, 'this should never happen!' - assert np.sum(a_factors != n_factors) == k, 'this should never happen!' - # return values - return tuple(self._gt_data.pos_to_idx([ - a_factors, - p_factors, - n_factors, - ])) - - def _sample_shared(self, base_factors, shared_mask, tries=100, sample_close: bool = False): - sampled_factors = base_factors.copy() - generate_mask = ~shared_mask - # generate values - for i in range(tries): - if sample_close: - sampled_values = (base_factors + np.random.randint(-1, 1+1, size=self._gt_data.num_factors)) - sampled_values = np.clip(sampled_values, 0, np.array(self._gt_data.factor_sizes) - 1)[generate_mask] - else: - sampled_values = np.random.randint(0, np.array(self._gt_data.factor_sizes)[generate_mask]) - # overwrite values that are not different - sampled_factors[generate_mask] = sampled_values - # update mask - sampled_shared_mask = (sampled_factors == base_factors) - generate_mask &= sampled_shared_mask - # check everything - if np.sum(sampled_shared_mask) == np.sum(shared_mask): - assert np.sum(generate_mask) == 0 - return sampled_factors - # we need to try again! 
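# NOTE (added description): the loop above is rejection sampling -- positions
# outside the shared mask are re-drawn until every one differs from
# `base_factors`; each pass shrinks `generate_mask` to just the positions that
# still collide, and we only give up here after `tries` failed rounds (which
# can happen when a factor has very few distinct values).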
- raise RuntimeError('could not generate factors: {}') - - -def sampler_print_test(sampler: Union[str, BaseDisentSampler], gt_data: GroundTruthData = None, steps=100): - # make data - if gt_data is None: - gt_data = H.make_dataset('xysquares_8x8_mini').gt_data - # make sampler - if isinstance(sampler, str): - prefix = sampler - sampler = make_adversarial_sampler(sampler) - else: - prefix = sampler.__class__.__name__ - if not sampler.is_init: - sampler.init(gt_data) - # print everything - count_pn_k0, count_pn_d0 = 0, 0 - for i in range(min(steps, len(gt_data))): - a, p, n = gt_data.idx_to_pos(sampler(i)) - ap_k = np.sum(a != p); ap_d = np.sum(np.abs(a - p)) - an_k = np.sum(a != n); an_d = np.sum(np.abs(a - n)) - pn_k = np.sum(p != n); pn_d = np.sum(np.abs(p - n)) - print(f'{prefix}: [{c.lGRN}ap{c.RST}:{ap_k:2d}:{ap_d:2d}] [{c.lRED}an{c.RST}:{an_k:2d}:{an_d:2d}] [{c.lYLW}pn{c.RST}:{pn_k:2d}:{pn_d:2d}] {a} {p} {n}') - count_pn_k0 += (pn_k == 0) - count_pn_d0 += (pn_d == 0) - print(f'count pn:(k=0) = {count_pn_k0} pn:(d=0) = {count_pn_d0}') - - -def make_adversarial_sampler(mode: str = 'close_far'): - if mode in ['random_swap_k', 'random_swap_manhattan', 'random_swap_manhattan_norm', 'random_swap_euclidean', 'random_swap_euclidean_norm']: - # NOTE # -- random_swap_manhattan -- probability is too low of encountering nearby obs, don't use this! - metric = mode[len('random_swap_'):] - return AdversarialSampler_SwappedRandom(swap_metric=metric) - elif mode in ['close_far', 'close_p_random_n']: - # *NB* # - return AdversarialSampler_CloseFar( - p_k_range=(1, 1), n_k_range=(1, -1), - p_radius_range=(1, 1), n_radius_range=(1, -1), - ) - elif mode in ['close_far_random', 'close_p_random_n_bb']: - # *NB* # - return GroundTruthTripleSampler( - p_k_range=(1, 1), n_k_range=(1, -1), n_k_sample_mode='bounded_below', n_k_is_shared=True, - p_radius_range=(1, 1), n_radius_range=(1, -1), n_radius_sample_mode='bounded_below', - ) - elif mode in ['same_k']: - # *NB* # - return AdversarialSampler_SameK(k='random', sample_p_close=False) - elif mode in ['same_k_close']: - # *NB* # - return AdversarialSampler_SameK(k='random', sample_p_close=True) - elif mode in ['same_k1_close']: - # *NB* # - return AdversarialSampler_SameK(k=1, sample_p_close=True) - elif mode == 'close_factor_far_random': - return GroundTruthTripleSampler( - p_k_range=(1, 1), n_k_range=(1, -1), n_k_sample_mode='bounded_below', n_k_is_shared=True, - p_radius_range=(1, -1), n_radius_range=(0, -1), n_radius_sample_mode='bounded_below', - ) - elif mode == 'close_far_same_factor': - # TODO: problematic for dsprites - return GroundTruthTripleSampler( - p_k_range=(1, 1), n_k_range=(1, 1), n_k_sample_mode='bounded_below', n_k_is_shared=True, - p_radius_range=(1, 1), n_radius_range=(2, -1), n_radius_sample_mode='bounded_below', - ) - elif mode == 'same_factor': - return GroundTruthTripleSampler( - p_k_range=(1, 1), n_k_range=(1, 1), n_k_sample_mode='bounded_below', n_k_is_shared=True, - p_radius_range=(1, -2), n_radius_range=(2, -1), n_radius_sample_mode='bounded_below', # bounded below does not always work, still relies on random chance :/ - ) - elif mode == 'random_bb': - return GroundTruthTripleSampler( - p_k_range=(0, -1), n_k_range=(0, -1), n_k_sample_mode='bounded_below', n_k_is_shared=True, - p_radius_range=(0, -1), n_radius_range=(0, -1), n_radius_sample_mode='bounded_below', - ) - elif mode == 'random_swap_manhat': - return GroundTruthTripleSampler( - p_k_range=(0, -1), n_k_range=(0, -1), n_k_sample_mode='random', n_k_is_shared=False, - 
-            p_radius_range=(0, -1), n_radius_range=(0, -1), n_radius_sample_mode='random',
-            swap_metric='manhattan'
-        )
-    elif mode == 'random_swap_manhat_norm':
-        return GroundTruthTripleSampler(
-            p_k_range=(0, -1), n_k_range=(0, -1), n_k_sample_mode='random', n_k_is_shared=False,
-            p_radius_range=(0, -1), n_radius_range=(0, -1), n_radius_sample_mode='random',
-            swap_metric='manhattan_norm'
-        )
-    elif mode == 'random':
-        return RandomSampler(num_samples=3)
-    else:
-        raise KeyError(f'invalid adversarial sampler: mode={repr(mode)}')
-
-
-# ========================================================================= #
-# Adversarial Sort                                                          #
-# ========================================================================= #
-
-
-@torch.no_grad()
-def sort_samples(a_x: torch.Tensor, p_x: torch.Tensor, n_x: torch.Tensor, sort_mode: str = 'none', pixel_loss_mode: str = 'mse'):
-    # NOTE: this function may mutate its inputs, so always use the returned values.
-    # do not sort!
-    if sort_mode == 'none':
-        return (a_x, p_x, n_x)
-    elif sort_mode == 'swap':
-        return (a_x, n_x, p_x)
-    # compute deltas
-    p_deltas = H.pairwise_loss(a_x, p_x, mode=pixel_loss_mode, mean_dtype=torch.float32, mask=None)
-    n_deltas = H.pairwise_loss(a_x, n_x, mode=pixel_loss_mode, mean_dtype=torch.float32, mask=None)
-    # get swap mask
-    if sort_mode == 'sort_inorder': swap_mask = p_deltas > n_deltas
-    elif sort_mode == 'sort_reverse': swap_mask = p_deltas < n_deltas
-    else: raise KeyError(f'invalid sort_mode: {repr(sort_mode)}, must be one of: ["none", "swap", "sort_inorder", "sort_reverse"]')
-    # handle mutate or copy
-    idx_swap = torch.where(swap_mask)
-    # swap memory values -- NOTE: `p_x[idx_swap], n_x[idx_swap] = n_x[idx_swap], p_x[idx_swap]` would also
-    #                      be safe, since advanced indexing on the right-hand side returns copies that are
-    #                      fully evaluated before assignment, but the explicit clone below is clearer
-    temp = torch.clone(n_x[idx_swap])
-    n_x[idx_swap] = p_x[idx_swap]
-    p_x[idx_swap] = temp
-    # done!
-    return (a_x, p_x, n_x)
-
-
-# ========================================================================= #
-# Adversarial Loss                                                          #
-# ========================================================================= #
-
-# anchor, positive, negative
-TensorTriple = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
-
-
-def _get_triple(x: TensorTriple, adversarial_swapped: bool):
-    if not adversarial_swapped:
-        a, p, n = x
-    else:
-        a, n, p = x
-    return a, p, n
-
-
-_MARGIN_MODES = {
-    'invert_margin',
-    'triplet_margin',
-}
-
-
-@lru_cache()
-def _parse_margin_mode(adversarial_mode: str):
-    # parse the MARGIN_MODES -- linear search
-    for margin_mode in _MARGIN_MODES:
-        if adversarial_mode == margin_mode:
-            raise KeyError(f'`{margin_mode}` is not valid, specify the margin in the name, eg. `{margin_mode}_0.01`')
-        elif adversarial_mode.startswith(f'{margin_mode}_'):
-            margin = float(adversarial_mode[len(f'{margin_mode}_'):])
-            return margin_mode, margin
-    # done!
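-    # no `<margin_mode>_<margin>` prefix matched, so `adversarial_mode` is a
-    # plain mode without a margin, for example:
-    #   _parse_margin_mode('invert_margin_0.01') -> ('invert_margin', 0.01)
-    #   _parse_margin_mode('triplet')            -> ('triplet', None)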
- return adversarial_mode, None - - -def adversarial_loss( - ys: TensorTriple, - xs: Optional[TensorTriple] = None, # only used if mask_deltas==True - # adversarial loss settings - adversarial_mode: str = 'invert_shift', - adversarial_swapped: bool = False, - adversarial_masking: bool = False, # requires `xs` to be set - adversarial_top_k: Optional[int] = None, - # pixel loss to get deltas settings - pixel_loss_mode: str = 'mse', - # statistics - return_stats: bool = False, -): - a_y, p_y, n_y = _get_triple(ys, adversarial_swapped=adversarial_swapped) - - # get mask - if adversarial_masking: - a_x, p_x, n_x = _get_triple(xs, adversarial_swapped=adversarial_swapped) - ap_mask, an_mask = (a_x != p_x), (a_x != n_x) - else: - ap_mask, an_mask = None, None - - # compute deltas - p_deltas = H.pairwise_loss(a_y, p_y, mode=pixel_loss_mode, mean_dtype=torch.float32, mask=ap_mask) - n_deltas = H.pairwise_loss(a_y, n_y, mode=pixel_loss_mode, mean_dtype=torch.float32, mask=an_mask) - deltas = (n_deltas - p_deltas) - - # parse mode - adversarial_mode, margin = _parse_margin_mode(adversarial_mode) - - # compute loss deltas - # AUTO-CONSTANT - if adversarial_mode == 'self': loss_deltas = torch.abs(deltas) - elif adversarial_mode == 'self_random': - # the above should be equivalent with the right sampling strategy? - all_deltas = torch.cat([p_deltas, n_deltas], dim=0) - indices = np.arange(len(all_deltas)) - np.random.shuffle(indices) - deltas = all_deltas[indices[len(deltas):]] - all_deltas[indices[:len(deltas)]] - loss_deltas = torch.abs(deltas) - # INVERT - elif adversarial_mode == 'invert': loss_deltas = torch.maximum(deltas, torch.zeros_like(deltas)) - elif adversarial_mode == 'invert_margin': loss_deltas = torch.maximum(margin + deltas, torch.zeros_like(deltas)) # invert_loss = torch.clamp_min(n_dist - p_dist + margin_max, 0) - elif adversarial_mode == 'invert_unbounded': loss_deltas = deltas - # TRIPLET - elif adversarial_mode == 'triplet': loss_deltas = torch.maximum(-deltas, torch.zeros_like(deltas)) - elif adversarial_mode == 'triplet_margin': loss_deltas = torch.maximum(margin - deltas, torch.zeros_like(deltas)) # triplet_loss = torch.clamp_min(p_dist - n_dist + margin_max, 0) - elif adversarial_mode == 'triplet_unbounded': loss_deltas = -deltas - # OTHER - else: - raise KeyError(f'invalid `adversarial_mode`: {repr(adversarial_mode)}') - - # checks - assert deltas.shape == loss_deltas.shape, 'this is a bug' - - # top k deltas - if adversarial_top_k is not None: - loss_deltas = torch.topk(loss_deltas, k=adversarial_top_k, largest=True).values - - # get average loss - loss = loss_deltas.mean() - - # return early - if not return_stats: - return loss - - # compute stats! 
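-    # the statistics below are for logging only, so they are computed under
-    # no_grad and never become part of the autograd graph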
- with torch.no_grad(): - loss_stats = { - 'stat/p_delta:mean': float(p_deltas.mean().cpu()), 'stat/p_delta:std': float(p_deltas.std().cpu()), - 'stat/n_delta:mean': float(n_deltas.mean().cpu()), 'stat/n_delta:std': float(n_deltas.std().cpu()), - 'stat/deltas:mean': float(loss_deltas.mean().cpu()), 'stat/deltas:std': float(loss_deltas.std().cpu()), - } - - return loss, loss_stats - - -# ========================================================================= # -# END # -# ========================================================================= # - - -# if __name__ == '__main__': -# -# def _main(): -# from disent.dataset.data import XYObjectData -# -# # NB: -# # close_p_random_n -# # close_p_random_n_bb -# # same_k -# # same_k_close -# # same_k1_close -# -# sampler_print_test( -# sampler='close_p_random_n', -# gt_data=XYObjectData() -# ) -# -# _main() diff --git a/research/e06_adversarial_data/util_load_adversarial_mask.py b/research/e06_adversarial_data/util_load_adversarial_mask.py deleted file mode 100644 index 481a1329..00000000 --- a/research/e06_adversarial_data/util_load_adversarial_mask.py +++ /dev/null @@ -1,78 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import gzip -import pickle -import numpy as np -import logging - - -log = logging.getLogger(__name__) - - -# ========================================================================= # -# HELPER # -# ========================================================================= # - - -def get_closest_mask(usage_ratio: float, pickle_file: str, print_n_best: int = 3) -> np.ndarray: - """ - This function is intended to be used with the data - generated by `run_04_gen_adversarial_ruck.py` - - The function finds the closest member in the population with - the matching statistic. The reason this function works is that - the population should consist only of near-pareto-optimal solutions. 
- - These solutions are found using NSGA2 - - Usage With Hydra Config: - _target_: research.e06_adversarial_data.util_load_adversarial_mask.get_closest_mask - usage_ratio: 0.5 - pickle_file: data.pkl.gz - """ - # load pickled data - with gzip.open(pickle_file, mode='rb') as fp: - data = pickle.load(fp) - values = np.array(data['values'], dtype='bool') - scores = np.array(data['scores'], dtype='float64') - del data - # check shapes - assert values.ndim == 2 - assert scores.ndim == 2 - assert scores.shape == (len(values), 2) - # get closest - best_indices = np.argsort(np.abs(scores[:, 1] - usage_ratio)) - # print stats - if print_n_best > 0: - log.info(f'The {print_n_best} closest members to target usage={usage_ratio:7f}') - for i, idx in enumerate(best_indices[:print_n_best]): - assert np.isclose(np.mean(values[idx]), scores[idx, 1]), 'member fitness_usage is not close to the actual mask usage. The data is invalid.' - log.info(f' [{i+1}] idx={idx:04d} overlap={scores[idx, 0]:7f} usage={scores[idx, 1]:7f}') - # return the best! - return values[best_indices[0]] - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/e07_metric/make_graphs.py b/research/e07_metric/make_graphs.py deleted file mode 100644 index 2ba8e720..00000000 --- a/research/e07_metric/make_graphs.py +++ /dev/null @@ -1,436 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-import itertools
-import os
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-
-import numpy as np
-import torch
-from matplotlib import cm
-from matplotlib import pyplot as plt
-from tqdm import tqdm
-
-import research.util as H
-from disent.metrics._flatness_components import compute_axis_score
-from disent.metrics._flatness_components import compute_linear_score
-from disent.util.seeds import seed
-
-
-# ========================================================================= #
-# distance function                                                         #
-# ========================================================================= #
-
-
-def _rotation_matrix(d, i, j, deg):
-    assert 0 <= i < j < d
-    mat = torch.eye(d, dtype=torch.float32)
-    r = np.deg2rad(deg)
-    s, c = np.sin(r), np.cos(r)
-    mat[i, i] = c
-    mat[j, j] = c
-    mat[j, i] = -s
-    mat[i, j] = s
-    return mat
-
-
-def rotation_matrix_2d(deg):
-    return _rotation_matrix(d=2, i=0, j=1, deg=deg)
-
-
-def _random_rotation_matrix(d):
-    mat = torch.eye(d, dtype=torch.float32)
-    for i in range(d):
-        for j in range(i+1, d):
-            mat @= _rotation_matrix(d, i, j, np.random.randint(0, 360))
-    return mat
-
-
-def make_2d_line_points(n: int = 100, deg: float = 30, std_x: float = 1.0, std_y: float = 0.005):
-    points = torch.randn(n, 2, dtype=torch.float32) * torch.as_tensor([[std_x, std_y]], dtype=torch.float32)
-    points = points @ rotation_matrix_2d(deg)
-    return points
-
-
-def make_nd_line_points(n: int = 100, dims: int = 4, std_x: float = 1.0, std_y: float = 0.005):
-    if not isinstance(dims, int):
-        m, M = dims
-        dims = np.random.randint(m, M)
-    # generate numbers
-    xs = torch.randn(n, dims, dtype=torch.float32)
-    # axis standard deviations
-    if isinstance(std_y, (float, int)):
-        std_y = torch.full((dims-1,), fill_value=std_y, dtype=torch.float32)
-    else:
-        m, M = std_y
-        std_y = torch.rand(dims-1, dtype=torch.float32) * (M - m) + m
-    # scale axes
-    std = torch.cat([torch.as_tensor([std_x]), std_y])
-    xs = xs * std[None, :]
-    # rotate
-    return xs @ _random_rotation_matrix(dims)
-
-
-def make_line_points(n: int = 100, deg: float = None, dims: int = 2, std_x: float = 1.0, std_y: float = 0.1):
-    if deg is None:
-        return make_nd_line_points(n=n, dims=dims, std_x=std_x, std_y=std_y)
-    else:
-        assert dims == 2, f'if "deg" is not None, then "dims" must equal 2, currently set to: {repr(dims)}'
-        return make_2d_line_points(n=n, deg=deg, std_x=std_x, std_y=std_y)
-
-
-# def random_line(std, n=100):
-#     std = torch.as_tensor(std, dtype=torch.float32)
-#     (d,) = std.shape
-#     # generate numbers
-#     xs = torch.randn(n, d, dtype=torch.float32)
-#     # scale axes
-#     xs = xs * std[None, :]
-#     # rotate
-#     return xs @ _random_rotation_matrix(d)
-
-
-# ========================================================================= #
-# GAUSSIAN                                                                  #
-# ========================================================================= #
-
-
-def gaussian_1d(x, s): return 1 / (np.sqrt(2 * np.pi) * s) * torch.exp(-(x**2)/(2*s**2))
-def gaussian_1d_dx(x, s): return gaussian_1d(x, s) * (-x/s**2)
-def gaussian_1d_dx2(x, s): return gaussian_1d(x, s) * ((x**2 - s**2)/s**4)
-
-
-def gaussian_2d(x, y, sx, sy): return gaussian_1d(x, sx) * gaussian_1d(y, sy)
-def gaussian_2d_dy(x, y, sx, sy): return gaussian_1d(x, sx) * gaussian_1d_dx(y, sy)
-def gaussian_2d_dy2(x, y, sx, sy): return gaussian_1d(x, sx) * gaussian_1d_dx2(y, sy)
-
-
-def rotated_radius_meshgrid(radius: float, num_points: int, deg: float = 0, device=None, return_orig=False) -> Tuple[torch.Tensor, torch.Tensor]:
-    # x & y values centered around zero
-    # p = torch.arange(size, device=device) - (size-1)/2
-    p = torch.linspace(-radius, radius, num_points, device=device)
-    x, y = torch.meshgrid(p, p)
-    # matrix multiplication along first axis | https://pytorch.org/docs/stable/generated/torch.einsum.html
-    rx, ry = torch.einsum('dxy,kd->kxy', torch.stack([x, y]), rotation_matrix_2d(deg))
-    # result
-    if return_orig:
-        return (rx, ry), (x, y)
-    return rx, ry
-
-
-def rotated_gaussian2d(std_x: float, std_y: float, deg: float, trunc_sigma: Optional[float] = None, num_points: int = 511):
-    radius = (2.25*max(std_x, std_y)) if (trunc_sigma is None) else trunc_sigma
-    (xs_r, ys_r), (xs, ys) = rotated_radius_meshgrid(radius=radius, num_points=num_points, deg=deg, return_orig=True)
-    zs = gaussian_2d(xs_r, ys_r, sx=std_x, sy=std_y)
-    zs /= zs.sum()
-    return xs, ys, zs
-
-
-def plot_gaussian(
-    deg: float = 0.0,
-    std_x: float = 1.0,
-    std_y: float = 0.1,
-    # contour
-    contour_resolution: int = 255,
-    contour_trunc_sigma: Optional[float] = None,
-    contour_kwargs: Optional[dict] = None,
-    # dots
-    dots_num: Optional[int] = None,
-    dots_kwargs: Optional[dict] = None,
-    # axis
-    ax=None,
-):
-    if ax is None:
-        fig = plt.figure()
-        ax = fig.gca()
-    # set limits
-    trunc_sigma = (2.05 * max(std_x, std_y)) if (contour_trunc_sigma is None) else contour_trunc_sigma
-    ax.set_xlim([-trunc_sigma, trunc_sigma])
-    ax.set_ylim([-trunc_sigma, trunc_sigma])
-    # plot contour
-    xs, ys, zs = rotated_gaussian2d(std_x=std_x, std_y=std_y, deg=deg, trunc_sigma=trunc_sigma, num_points=contour_resolution)
-    ax.contourf(xs, ys, zs, **({} if contour_kwargs is None else contour_kwargs))
-    # plot dots
-    if dots_num is not None:
-        points = make_line_points(n=dots_num, dims=2, deg=deg, std_x=std_x, std_y=std_y)
-        ax.scatter(*points.T, **({} if dots_kwargs is None else dots_kwargs))
-    # done
-    return ax
-
-
-# ========================================================================= #
-# Generate Average Plots                                                    #
-# ========================================================================= #
-
-
-def score_grid(
-    deg_rotations: Sequence[Optional[float]],
-    y_std_ratios: Sequence[float],
-    x_std: float = 1.0,
-    num_points: int = 1000,
-    num_dims: int = 2,
-    use_std: bool = True,
-    use_max: bool = False,
-    norm: bool = True,
-    return_points: bool = False,
-):
-    h, w = len(y_std_ratios), len(deg_rotations)
-    # grids
-    axis_scores = torch.zeros([h, w], dtype=torch.float64)
-    linear_scores = torch.zeros([h, w], dtype=torch.float64)
-    if return_points:
-        all_points = torch.zeros([h, w, num_points, num_dims], dtype=torch.float64)
-    # compute scores
-    for i, y_std_ratio in enumerate(y_std_ratios):
-        for j, deg in enumerate(deg_rotations):
-            points = make_line_points(n=num_points, dims=num_dims, deg=deg, std_x=x_std, std_y=x_std * y_std_ratio)
-            axis_scores[i, j] = compute_axis_score(points, use_std=use_std, use_max=use_max, norm=norm)
-            linear_scores[i, j] = compute_linear_score(points, use_std=use_std, use_max=use_max, norm=norm)
-            if return_points:
-                all_points[i, j] = points
-    # results
-    if return_points:
-        return axis_scores, linear_scores, all_points
-    return axis_scores, linear_scores
-
-
-def ave_score_grid(
-    deg_rotations: Sequence[Optional[float]],
-    y_std_ratios: Sequence[float],
-    x_std: float = 1.0,
-    num_points: int = 1000,
-    num_dims: int = 2,
-    use_std: bool = True,
-    use_max: bool = False,
-    norm: bool = True,
-    repeats: int = 10,
-):
-    results = []
-    # repeat
-    for i in tqdm(range(repeats)):
results.append(score_grid(deg_rotations=deg_rotations, y_std_ratios=y_std_ratios, x_std=x_std, num_points=num_points, num_dims=num_dims, use_std=use_std, use_max=use_max, norm=norm)) - # average results - all_axis_scores, all_linear_scores = zip(*results) - axis_scores = torch.mean(torch.stack(all_axis_scores, dim=0), dim=0) - linear_scores = torch.mean(torch.stack(all_linear_scores, dim=0), dim=0) - # results - return axis_scores, linear_scores - - -def make_ave_scores_plot( - std_num: int = 21, - deg_num: int = 21, - ndim: Optional[int] = None, - # extra - num_points: int = 1000, - repeats: int = 25, - x_std: float = 1.0, - use_std: bool = True, - use_max: bool = False, - norm: bool = True, - # cmap - cmap_axis: str = 'GnBu_r', # 'RdPu_r', 'GnBu_r', 'Blues_r', 'viridis', 'plasma', 'magma' - cmap_linear: str = 'RdPu_r', # 'RdPu_r', 'GnBu_r', 'Blues_r', 'viridis', 'plasma', 'magma' - vertical: bool = True, - # subplot settings - subplot_size: float = 4., - subplot_padding: float = 1.5, -): - # make sure to handle the random case - deg_num = std_num if (ndim is None) else deg_num - axis_scores, linear_scores = ave_score_grid( - deg_rotations=np.linspace(0., 180., num=deg_num) if (ndim is None) else [None], - y_std_ratios=np.linspace(0., 1., num=std_num), - x_std=x_std, - num_points=num_points, - num_dims=2 if (ndim is None) else ndim, - use_std=use_std, - use_max=use_max, - norm=norm, - repeats=repeats, - ) - # make plot - fig, axs = H.plt_subplots( - nrows=1+int(vertical), - ncols=1+int(not vertical), - titles=['Linear', 'Axis'], - row_labels=f'$σ_y$ - Standard Deviation', - col_labels=f'θ - Rotation Degrees', - figsize=(subplot_size + 0.5, subplot_size * 2 * (deg_num / std_num) + 0.75)[::1 if vertical else -1] - ) - (ax0, ax1) = axs.flatten() - # subplots - ax0.imshow(linear_scores, cmap=cmap_linear, extent=[0., 180., 1., 0.]) - ax1.imshow(axis_scores, cmap=cmap_axis, extent=[0., 180., 1., 0.]) - for ax in axs.flatten(): - ax.set_aspect(180 * (std_num / deg_num)) - if len(ax.get_xticks()): - ax.set_xticks(np.linspace(0., 180., 5)) - # layout - fig.tight_layout(pad=subplot_padding) - # done - return fig, axs - - -# ========================================================================= # -# HELPER # -# ========================================================================= # - - -def plot_scores(ax, axis_score, linear_score): - from matplotlib.lines import Line2D - assert 0 <= linear_score <= 1 - assert 0 <= axis_score <= 1 - linear_rgb = cm.get_cmap('RdPu_r')(np.clip(linear_score, 0., 1.)) - axis_rgb = cm.get_cmap('GnBu_r')(np.clip(axis_score, 0., 1.)) - ax.legend(handles=[ - Line2D([0], [0], label=f'Linear: {float(linear_score):.2f}', color=linear_rgb, marker='o', markersize=10, linestyle='None'), - Line2D([0], [0], label=f'Axis: {float(axis_score):.2f}', color=axis_rgb, marker='o', markersize=10, linestyle='None'), - ]) - return ax - - -# ========================================================================= # -# Generate Grid Plots # -# ========================================================================= # - - -def make_grid_gaussian_score_plot( - # grid - y_stds: Sequence[float] = (0.8, 0.2, 0.05)[::-1], # (0.8, 0.4, 0.2, 0.1, 0.05), - deg_rotations: Sequence[float] = (0, 22.5, 45, 67.5, 90, 112.5, 135, 157.5), # (0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165), - # plot dot options - dots_num: Optional[int] = None, - # score options - num_points: int = 10000, - repeats: int = 100, - use_std: bool = True, - use_max: bool = False, - norm: bool = True, - # grid options - 
subplot_size: float = 2.125, - subplot_padding: float = 0.5, - subplot_contour_kwargs: Optional[dict] = None, - subplot_dots_kwargs: Optional[dict] = None, -): - # defaults - if subplot_contour_kwargs is None: subplot_contour_kwargs = dict(cmap='Blues') - if subplot_dots_kwargs is None: subplot_dots_kwargs = dict(cmap='Purples') - - # make figure - nrows, ncols = len(y_stds), len(deg_rotations) - fig, axs = H.plt_subplots( - nrows=nrows, ncols=ncols, - row_labels=[f'$σ_y$ = {std_y}' for std_y in y_stds], - col_labels=[f'θ = {deg}°' for deg in deg_rotations], - hide_axis='all', - figsize=(ncols*subplot_size, nrows*subplot_size), - ) - - # progress - p = tqdm(total=axs.size, desc='generating_plot') - # generate plot - for (y, std_y), (x, deg) in itertools.product(enumerate(y_stds), enumerate(deg_rotations)): - # compute scores - axis_score, linear_score = [], [] - for k in range(repeats): - points = make_2d_line_points(n=num_points, deg=deg, std_x=1.0, std_y=std_y) - axis_score.append(compute_axis_score(points, use_std=use_std, use_max=use_max, norm=norm)) - linear_score.append(compute_linear_score(points, use_std=use_std, use_max=use_max, norm=norm)) - axis_score, linear_score = np.mean(axis_score), np.mean(linear_score) - # generate subplots - plot_gaussian(ax=axs[y, x], deg=deg, std_x=1.0, std_y=std_y, dots_num=dots_num, contour_trunc_sigma=2.05, contour_kwargs=subplot_contour_kwargs, dots_kwargs=subplot_dots_kwargs) - plot_scores(ax=axs[y, x], axis_score=axis_score, linear_score=linear_score) - # update progress - p.update() - plt.tight_layout(pad=subplot_padding) - - return fig, axs - - -# ========================================================================= # -# MAIN # -# ========================================================================= # - - -if __name__ == '__main__': - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - - # plot everything - seed(777) - make_grid_gaussian_score_plot( - repeats=250, - num_points=25000, - ) - plt.savefig(H.make_rel_path_add_ext('plots/metric_grid', ext='.png')) - plt.show() - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - - # plot everything -- minimal - seed(777) - make_grid_gaussian_score_plot( - y_stds=(0.8, 0.4, 0.2, 0.1, 0.05)[::-1], # (0.8, 0.4, 0.2, 0.1, 0.05), - deg_rotations=(0, 22.5, 45, 67.5, 90), - repeats=250, - num_points=25000, - ) - plt.savefig(H.make_rel_path_add_ext('plots/metric_grid_minimal_5x5', ext='.png')) - plt.show() - - # plot everything -- minimal - seed(777) - make_grid_gaussian_score_plot( - y_stds=(0.8, 0.4, 0.2, 0.05)[::-1], # (0.8, 0.4, 0.2, 0.1, 0.05), - deg_rotations=(0, 22.5, 45, 67.5, 90), - repeats=250, - num_points=25000, - ) - plt.savefig(H.make_rel_path_add_ext('plots/metric_grid_minimal_4x5', ext='.png')) - plt.show() - - # plot everything -- minimal - seed(777) - fig, axs = make_grid_gaussian_score_plot( - y_stds=(0.8, 0.2, 0.05)[::-1], # (0.8, 0.4, 0.2, 0.1, 0.05), - deg_rotations=(0, 22.5, 45, 67.5, 90), - repeats=250, - num_points=25000, - ) - plt.savefig(H.make_rel_path_add_ext('plots/metric_grid_minimal_3x5', ext='.png')) - plt.show() - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - - seed(777) - make_ave_scores_plot(repeats=250, num_points=10000, use_max=False) - plt.savefig(H.make_rel_path_add_ext('plots/metric_scores', ext='.png')) - plt.show() - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - # diff --git a/research/e08_autoencoders/submit_01.sh b/research/e08_autoencoders/submit_01.sh deleted file mode 100644 index 8e5086a5..00000000 --- a/research/e08_autoencoders/submit_01.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="final-08__autoencoder-versions" -export PARTITION="stampede" -export PARALLELISM=32 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# 1 * (2*2*3*3*8) == 288 -submit_sweep \ - +DUMMY.repeat=1 \ - +EXTRA.tags='various-auto-encoders' \ - \ - run_length=short,long \ - schedule=adavae_up_ratio_full,adavae_up_all_full,none \ - \ - dataset=xysquares,cars3d,shapes3d \ - framework=ae,tae,X--adaae,X--adanegtae,vae,tvae,adavae,X--adanegtvae \ - model=conv64alt \ - model.z_size=25 \ - \ - sampling=gt_dist_manhat,gt_dist_manhat_scaled diff --git a/research/e09_vae_overlap_loss/submit_overlap_loss.sh b/research/e09_vae_overlap_loss/submit_overlap_loss.sh deleted file mode 100644 index d51572b2..00000000 --- a/research/e09_vae_overlap_loss/submit_overlap_loss.sh +++ /dev/null @@ -1,125 +0,0 @@ -#!/bin/bash - -# OVERVIEW: -# - this experiment is designed to test how changing the reconstruction loss to match the -# ground-truth distances allows datasets to be disentangled. - - -# OUTCOMES: -# - When the reconstruction loss is used as a distance function between observations, and those -# distances match the ground truth, it enables disentanglement. -# - Loss must still be able to reconstruct the inputs correctly. 
-# - AEs have no incentive to learn the same distances as VAEs - - -# ========================================================================= # -# Settings # -# ========================================================================= # - -export USERNAME="n_michlo" -export PROJECT="CVPR-09__vae_overlap_loss" -export PARTITION="stampede" -export PARALLELISM=28 - -# source the helper file -source "$(dirname "$(dirname "$(realpath -s "$0")")")/helper.sh" - -# ========================================================================= # -# Experiment # -# ========================================================================= # - -clog_cudaless_nodes "$PARTITION" 86400 "C-disent" # 24 hours - -# TEST MSE vs BoxBlur MSE (with different beta values over different datasets) -# - mse boxblur weight is too strong, need to lower significantly -# 1 * (5 * 2*4*2) = 80 -#submit_sweep \ -# +DUMMY.repeat=1 \ -# +EXTRA.tags='sweep_overlap_boxblur' \ -# hydra.job.name="ovlp_loss" \ -# \ -# +VAR.recon_loss_weight=1.0 \ -# +VAR.kernel_loss_weight=3969.0 \ -# +VAR.kernel_radius=31 \ -# \ -# run_length=medium \ -# metrics=all \ -# \ -# dataset=X--xysquares,dsprites,shapes3d,smallnorb,cars3d \ -# \ -# framework=betavae,adavae_os \ -# settings.framework.beta=0.0316,0.316,0.1,0.01 \ -# settings.model.z_size=25,9 \ -# settings.framework.recon_loss='mse_box_r${VAR.kernel_radius}_l${VAR.recon_loss_weight}_k${VAR.kernel_loss_weight}' \ -# \ -# sampling=default__bb - - -# TEST MSE vs BoxBlur MSE -# - changing the reconstruction loss enables disentanglement -# 5 * (2*2*2 = 8) = 40 -submit_sweep \ - +DUMMY.repeat=1,2,3,4,5 \ - +EXTRA.tags='sweep_overlap_boxblur_specific' \ - hydra.job.name="s_ovlp_loss" \ - \ - +VAR.recon_loss_weight=1.0 \ - +VAR.kernel_loss_weight=3969.0 \ - +VAR.kernel_radius=31 \ - \ - run_length=medium \ - metrics=all \ - \ - dataset=X--xysquares \ - \ - framework=betavae,adavae_os \ - settings.framework.beta=0.0316,0.0001 \ - settings.model.z_size=25 \ - settings.framework.recon_loss=mse,'mse_box_r${VAR.kernel_radius}_l${VAR.recon_loss_weight}_k${VAR.kernel_loss_weight}' \ - \ - sampling=default__bb - - -# TEST DISTANCES IN AEs VS VAEs -# -- supplementary material -# 3 * (1 * 2 = 2) = 6 -submit_sweep \ - +DUMMY.repeat=1,2,3 \ - +EXTRA.tags='sweep_overlap_boxblur_autoencoders' \ - hydra.job.name="e_ovlp_loss" \ - \ - +VAR.recon_loss_weight=1.0 \ - +VAR.kernel_loss_weight=3969.0 \ - +VAR.kernel_radius=31 \ - \ - run_length=medium \ - metrics=all \ - \ - dataset=X--xysquares \ - \ - framework=ae \ - settings.framework.beta=0.0001 \ - settings.model.z_size=25 \ - settings.framework.recon_loss=mse,'mse_box_r${VAR.kernel_radius}_l${VAR.recon_loss_weight}_k${VAR.kernel_loss_weight}' \ - \ - sampling=default__bb - - -# HPARAM SWEEP -- TODO: update -# -- old, unused -# 1 * (2 * 8 * 2 * 2) = 160 -#submit_sweep \ -# +DUMMY.repeat=1 \ -# +EXTRA.tags='sweep_beta' \ -# hydra.job.name="vae_hparams" \ -# \ -# run_length=long \ -# metrics=all \ -# \ -# settings.framework.beta=0.000316,0.001,0.00316,0.01,0.0316,0.1,0.316,1.0 \ -# framework=betavae,adavae_os \ -# schedule=none \ -# settings.model.z_size=9,25 \ -# \ -# dataset=X--xysquares \ -# sampling=default__bb diff --git a/research/gadfly.mplstyle b/research/gadfly.mplstyle deleted file mode 100644 index 00b215ac..00000000 --- a/research/gadfly.mplstyle +++ /dev/null @@ -1,627 +0,0 @@ -#### MATPLOTLIBRC FORMAT - -# FROM: https://towardsdatascience.com/a-new-plot-theme-for-matplotlib-gadfly-2cffc745ff84 - -## This is a sample matplotlib configuration file - you 
can find a copy -## of it on your system in -## site-packages/matplotlib/mpl-data/matplotlibrc. If you edit it -## there, please note that it will be overwritten in your next install. -## If you want to keep a permanent local copy that will not be -## overwritten, place it in the following location: -## unix/linux: -## $HOME/.config/matplotlib/matplotlibrc or -## $XDG_CONFIG_HOME/matplotlib/matplotlibrc (if $XDG_CONFIG_HOME is set) -## other platforms: -## $HOME/.matplotlib/matplotlibrc -## -## See http://matplotlib.org/users/customizing.html#the-matplotlibrc-file for -## more details on the paths which are checked for the configuration file. -## -## This file is best viewed in a editor which supports python mode -## syntax highlighting. Blank lines, or lines starting with a comment -## symbol, are ignored, as are trailing comments. Other lines must -## have the format -## key : val ## optional comment -## -## Colors: for the color values below, you can either use - a -## matplotlib color string, such as r, k, or b - an rgb tuple, such as -## (1.0, 0.5, 0.0) - a hex string, such as ff00ff - a scalar -## grayscale intensity such as 0.75 - a legal html color name, e.g., red, -## blue, darkslategray - -##### CONFIGURATION BEGINS HERE - -## The default backend; one of GTK3Agg GTK3Cairo MacOSX Qt4Agg Qt5Agg TkAgg -## WX WXAgg Agg Cairo PS PDF SVG Template. -## You can also deploy your own backend outside of matplotlib by -## referring to the module name (which must be in the PYTHONPATH) as -## 'module://my_backend'. -## -## If you omit this parameter, the backend will be determined by fallback. -#backend : Agg - -## Note that this can be overridden by the environment variable -## QT_API used by Enthought Tool Suite (ETS); valid values are -## "pyqt" and "pyside". The "pyqt" setting has the side effect of -## forcing the use of Version 2 API for QString and QVariant. - -## The port to use for the web server in the WebAgg backend. -#webagg.port : 8988 - -## The address on which the WebAgg web server should be reachable -#webagg.address : 127.0.0.1 - -## If webagg.port is unavailable, a number of other random ports will -## be tried until one that is available is found. -#webagg.port_retries : 50 - -## When True, open the webbrowser to the plot that is shown -#webagg.open_in_browser : True - -## if you are running pyplot inside a GUI and your backend choice -## conflicts, we will automatically try to find a compatible one for -## you if backend_fallback is True -#backend_fallback: True - -#interactive : False -#toolbar : toolbar2 ## None | toolbar2 ("classic" is deprecated) -#timezone : UTC ## a pytz timezone string, e.g., US/Central or Europe/Paris - -## Where your matplotlib data lives if you installed to a non-default -## location. This is where the matplotlib fonts, bitmaps, etc reside -#datapath : /home/jdhunter/mpldata - - -#### LINES -## See http://matplotlib.org/api/artist_api.html#module-matplotlib.lines for more -## information on line properties. 
-lines.linewidth : 2 ## line width in points -#lines.linestyle : - ## solid line -#lines.color : C0 ## has no affect on plot(); see axes.prop_cycle -#lines.marker : None ## the default marker -# lines.markerfacecolor : auto ## the default markerfacecolor -lines.markeredgecolor : white ## the default markeredgecolor -lines.markeredgewidth : 1 ## the line width around the marker symbol -lines.markersize : 7 ## markersize, in points -#lines.dash_joinstyle : round ## miter|round|bevel -#lines.dash_capstyle : butt ## butt|round|projecting -#lines.solid_joinstyle : round ## miter|round|bevel -#lines.solid_capstyle : projecting ## butt|round|projecting -#lines.antialiased : True ## render lines in antialiased (no jaggies) - -## The three standard dash patterns. These are scaled by the linewidth. -#lines.dashed_pattern : 3.7, 1.6 -#lines.dashdot_pattern : 6.4, 1.6, 1, 1.6 -#lines.dotted_pattern : 1, 1.65 -#lines.scale_dashes : True - -#markers.fillstyle: full ## full|left|right|bottom|top|none - -#### PATCHES -## Patches are graphical objects that fill 2D space, like polygons or -## circles. See -## http://matplotlib.org/api/artist_api.html#module-matplotlib.patches -## information on patch properties -patch.linewidth : 1 ## edge width in points. -patch.facecolor : C0 -patch.edgecolor : black ## if forced, or patch is not filled -#patch.force_edgecolor : False ## True to always use edgecolor -#patch.antialiased : True ## render patches in antialiased (no jaggies) - -#### HATCHES -#hatch.color : black -#hatch.linewidth : 1.0 - -#### Boxplot -#boxplot.notch : False -#boxplot.vertical : True -#boxplot.whiskers : 1.5 -# boxplot.bootstrap : None -boxplot.patchartist : True -#boxplot.showmeans : False -#boxplot.showcaps : True -#boxplot.showbox : True -#boxplot.showfliers : True -#boxplot.meanline : False - -boxplot.flierprops.color : C0 -boxplot.flierprops.marker : o -boxplot.flierprops.markerfacecolor : auto -boxplot.flierprops.markeredgecolor : white -boxplot.flierprops.markersize : 7 -boxplot.flierprops.linestyle : none -boxplot.flierprops.linewidth : 1.0 - -boxplot.boxprops.color : 9ae1f9 -boxplot.boxprops.linewidth : 0 -boxplot.boxprops.linestyle : - - -boxplot.whiskerprops.color : C0 -boxplot.whiskerprops.linewidth : 1.0 -boxplot.whiskerprops.linestyle : - - -boxplot.capprops.color : C0 -boxplot.capprops.linewidth : 1.0 -boxplot.capprops.linestyle : - - -boxplot.medianprops.color : 9ae1f9 -boxplot.medianprops.linewidth : 1 -boxplot.medianprops.linestyle : - - -boxplot.meanprops.color : C1 -boxplot.meanprops.marker : ^ -boxplot.meanprops.markerfacecolor : C1 -boxplot.meanprops.markeredgecolor : C1 -boxplot.meanprops.markersize : 7 -boxplot.meanprops.linestyle : -- -boxplot.meanprops.linewidth : 1.0 - - -#### FONT - -## font properties used by text.Text. See -## http://matplotlib.org/api/font_manager_api.html for more -## information on font properties. The 6 font properties used for font -## matching are given below with their default values. -## -## The font.family property has five values: 'serif' (e.g., Times), -## 'sans-serif' (e.g., Helvetica), 'cursive' (e.g., Zapf-Chancery), -## 'fantasy' (e.g., Western), and 'monospace' (e.g., Courier). Each of -## these font families has a default list of font names in decreasing -## order of priority associated with them. When text.usetex is False, -## font.family may also be one or more concrete font names. -## -## The font.style property has three values: normal (or roman), italic -## or oblique. 
The oblique style will be used for italic, if it is not -## present. -## -## The font.variant property has two values: normal or small-caps. For -## TrueType fonts, which are scalable fonts, small-caps is equivalent -## to using a font size of 'smaller', or about 83%% of the current font -## size. -## -## The font.weight property has effectively 13 values: normal, bold, -## bolder, lighter, 100, 200, 300, ..., 900. Normal is the same as -## 400, and bold is 700. bolder and lighter are relative values with -## respect to the current weight. -## -## The font.stretch property has 11 values: ultra-condensed, -## extra-condensed, condensed, semi-condensed, normal, semi-expanded, -## expanded, extra-expanded, ultra-expanded, wider, and narrower. This -## property is not currently implemented. -## -## The font.size property is the default font size for text, given in pts. -## 10 pt is the standard value. - -#font.family : sans-serif -#font.style : normal -#font.variant : normal -#font.weight : normal -#font.stretch : normal -## note that font.size controls default text sizes. To configure -## special text sizes tick labels, axes, labels, title, etc, see the rc -## settings for axes and ticks. Special text sizes can be defined -## relative to font.size, using the following values: xx-small, x-small, -## small, medium, large, x-large, xx-large, larger, or smaller -#font.size : 10.0 -#font.serif : DejaVu Serif, Bitstream Vera Serif, Computer Modern Roman, New Century Schoolbook, Century Schoolbook L, Utopia, ITC Bookman, Bookman, Nimbus Roman No9 L, Times New Roman, Times, Palatino, Charter, serif -#font.sans-serif : DejaVu Sans, Bitstream Vera Sans, Computer Modern Sans Serif, Lucida Grande, Verdana, Geneva, Lucid, Arial, Helvetica, Avant Garde, sans-serif -#font.cursive : Apple Chancery, Textile, Zapf Chancery, Sand, Script MT, Felipa, cursive -#font.fantasy : Comic Sans MS, Chicago, Charcoal, ImpactWestern, Humor Sans, xkcd, fantasy -#font.monospace : DejaVu Sans Mono, Bitstream Vera Sans Mono, Computer Modern Typewriter, Andale Mono, Nimbus Mono L, Courier New, Courier, Fixed, Terminal, monospace - -#### TEXT -## text properties used by text.Text. See -## http://matplotlib.org/api/artist_api.html#module-matplotlib.text for more -## information on text properties -text.color : 707074 - -#### LaTeX customizations. See http://wiki.scipy.org/Cookbook/Matplotlib/UsingTex -#text.usetex : False ## use latex for all text handling. The following fonts - ## are supported through the usual rc parameter settings: - ## new century schoolbook, bookman, times, palatino, - ## zapf chancery, charter, serif, sans-serif, helvetica, - ## avant garde, courier, monospace, computer modern roman, - ## computer modern sans serif, computer modern typewriter - ## If another font is desired which can loaded using the - ## LaTeX \usepackage command, please inquire at the - ## matplotlib mailing list -#text.latex.preamble : ## IMPROPER USE OF THIS FEATURE WILL LEAD TO LATEX FAILURES - ## AND IS THEREFORE UNSUPPORTED. PLEASE DO NOT ASK FOR HELP - ## IF THIS FEATURE DOES NOT DO WHAT YOU EXPECT IT TO. - ## preamble is a comma separated list of LaTeX statements - ## that are included in the LaTeX document preamble. - ## An example: - ## text.latex.preamble : \usepackage{bm},\usepackage{euler} - ## The following packages are always loaded with usetex, so - ## beware of package collisions: color, geometry, graphicx, - ## type1cm, textcomp. 
Adobe Postscript (PSSNFS) font packages - ## may also be loaded, depending on your font settings -#text.latex.preview : False - -#text.hinting : auto ## May be one of the following: - ## none: Perform no hinting - ## auto: Use FreeType's autohinter - ## native: Use the hinting information in the - # font file, if available, and if your - # FreeType library supports it - ## either: Use the native hinting information, - # or the autohinter if none is available. - ## For backward compatibility, this value may also be - ## True === 'auto' or False === 'none'. -#text.hinting_factor : 8 ## Specifies the amount of softness for hinting in the - ## horizontal direction. A value of 1 will hint to full - ## pixels. A value of 2 will hint to half pixels etc. -#text.antialiased : True ## If True (default), the text will be antialiased. - ## This only affects the Agg backend. - -## The following settings allow you to select the fonts in math mode. -## They map from a TeX font name to a fontconfig font pattern. -## These settings are only used if mathtext.fontset is 'custom'. -## Note that this "custom" mode is unsupported and may go away in the -## future. -#mathtext.cal : cursive -#mathtext.rm : sans -#mathtext.tt : monospace -#mathtext.it : sans:italic -#mathtext.bf : sans:bold -#mathtext.sf : sans -#mathtext.fontset : dejavusans ## Should be 'dejavusans' (default), - ## 'dejavuserif', 'cm' (Computer Modern), 'stix', - ## 'stixsans' or 'custom' -#mathtext.fallback_to_cm : True ## When True, use symbols from the Computer Modern - ## fonts when a symbol can not be found in one of - ## the custom math fonts. -#mathtext.default : it ## The default font to use for math. - ## Can be any of the LaTeX font names, including - ## the special name "regular" for the same font - ## used in regular text. - -#### AXES -## default face and edge color, default tick sizes, -## default fontsizes for ticklabels, and so on. See -## http://matplotlib.org/api/axes_api.html#module-matplotlib.axes -#axes.facecolor : white ## axes background color -axes.edgecolor : D0D0E0 ## axes edge color -#axes.linewidth : 0.8 ## edge linewidth -axes.grid : True ## display grid or not -axes.grid.axis : both ## which axis the grid should apply to -#axes.grid.which : major ## gridlines at major, minor or both ticks -axes.titlesize : 18 ## fontsize of the axes title -#axes.titleweight : normal ## font weight of title -#axes.titlepad : 6.0 ## pad between axes and title in points -axes.labelsize : 14 ## fontsize of the x any y labels -#axes.labelpad : 4.0 ## space between label and axis -#axes.labelweight : normal ## weight of the x and y labels -axes.labelcolor : 707074 -#axes.axisbelow : line ## draw axis gridlines and ticks below - ## patches (True); above patches but below - ## lines ('line'); or above all (False) -#axes.formatter.limits : -7, 7 ## use scientific notation if log10 - ## of the axis range is smaller than the - ## first or larger than the second -#axes.formatter.use_locale : False ## When True, format tick labels - ## according to the user's locale. - ## For example, use ',' as a decimal - ## separator in the fr_FR locale. -#axes.formatter.use_mathtext : False ## When True, use mathtext for scientific - ## notation. -#axes.formatter.min_exponent: 0 ## minimum exponent to format in scientific notation -#axes.formatter.useoffset : True ## If True, the tick label formatter - ## will default to labeling ticks relative - ## to an offset when the data range is - ## small compared to the minimum absolute - ## value of the data. 
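-                                      ## e.g. with useoffset enabled, ticks
-                                      ## spanning 10000.1..10000.3 may be drawn
-                                      ## as 0.1..0.3 plus a "+1e4" offset label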
-#axes.formatter.offset_threshold : 4 ## When useoffset is True, the offset - ## will be used when it can remove - ## at least this number of significant - ## digits from tick labels. -axes.spines.left : False ## display axis spines -axes.spines.bottom : False -axes.spines.top : False -axes.spines.right : False -#axes.unicode_minus : True ## use unicode for the minus symbol - ## rather than hyphen. See - ## http://en.wikipedia.org/wiki/Plus_and_minus_signs#Character_codes -## ========================================================================================== ## -## ========================================================================================== ## -## ========================================================================================== ## -## COLOR PALETTE -# v1 https://coolors.co/2364aa-3da5d9-4ebc93-b4da1b-fbcc23-ec8232-e40066-df26cf-ae5ce6-9b899f -# v2 https://coolors.co/3482d5-66b8e1-5cc19c-b9d548-fbc737-f2822c-ff338f-d54ee4-a072e9-9b899f -axes.prop_cycle : cycler('color', ['3482d5', '66b8e1', '5cc19c', 'b9d548', 'fbc737', 'f2822c', 'ff338f', 'd54ee4', 'a072e9', '9b899f']) ## CUSTOM -## axes.prop_cycle : cycler('color', ['00BEFF', 'D4CA3A', 'FF6DAE', '67E1B5', 'EBACFA', '9E9E9E', 'F1988E', '5DB15A', 'E28544', '52B8AA']) ## ORIG - ## color cycle for plot lines as list of string - ## colorspecs: single letter, long name, or web-style hex - ## Note the use of string escapes here ('1f77b4', instead of 1f77b4) - ## as opposed to the rest of this file. -## ========================================================================================== ## -## ========================================================================================== ## -## ========================================================================================== ## -#axes.autolimit_mode : data ## How to scale axes limits to the data. - ## Use "data" to use data limits, plus some margin - ## Use "round_number" move to the nearest "round" number -#axes.xmargin : .05 ## x margin. See `axes.Axes.margins` -#axes.ymargin : .05 ## y margin See `axes.Axes.margins` -#polaraxes.grid : True ## display grid on polar axes -#axes3d.grid : True ## display grid on 3d axes - -#### DATES -## These control the default format strings used in AutoDateFormatter. -## Any valid format datetime format string can be used (see the python -## `datetime` for details). For example using '%%x' will use the locale date representation -## '%%X' will use the locale time representation and '%%c' will use the full locale datetime -## representation. 
-## These values map to the scales: -## {'year': 365, 'month': 30, 'day': 1, 'hour': 1/24, 'minute': 1 / (24 * 60)} - -#date.autoformatter.year : %Y -#date.autoformatter.month : %Y-%m -#date.autoformatter.day : %Y-%m-%d -#date.autoformatter.hour : %m-%d %H -#date.autoformatter.minute : %d %H:%M -#date.autoformatter.second : %H:%M:%S -#date.autoformatter.microsecond : %M:%S.%f - -#### TICKS -## see http://matplotlib.org/api/axis_api.html#matplotlib.axis.Tick -#xtick.top : False ## draw ticks on the top side -#xtick.bottom : True ## draw ticks on the bottom side -#xtick.labeltop : False ## draw label on the top -#xtick.labelbottom : True ## draw label on the bottom -#xtick.major.size : 3.5 ## major tick size in points -#xtick.minor.size : 2 ## minor tick size in points -#xtick.major.width : 0.8 ## major tick width in points -#xtick.minor.width : 0.6 ## minor tick width in points -#xtick.major.pad : 3.5 ## distance to major tick label in points -#xtick.minor.pad : 3.4 ## distance to the minor tick label in points -xtick.color : 707074 ## color of the tick labels -xtick.labelsize : 12 ## fontsize of the tick labels -#xtick.direction : out ## direction: in, out, or inout -#xtick.minor.visible : False ## visibility of minor ticks on x-axis -#xtick.major.top : True ## draw x axis top major ticks -#xtick.major.bottom : True ## draw x axis bottom major ticks -#xtick.minor.top : True ## draw x axis top minor ticks -#xtick.minor.bottom : True ## draw x axis bottom minor ticks -#xtick.alignment : center ## alignment of xticks - -#ytick.left : True ## draw ticks on the left side -#ytick.right : False ## draw ticks on the right side -#ytick.labelleft : True ## draw tick labels on the left side -#ytick.labelright : False ## draw tick labels on the right side -#ytick.major.size : 3.5 ## major tick size in points -#ytick.minor.size : 2 ## minor tick size in points -#ytick.major.width : 0.8 ## major tick width in points -#ytick.minor.width : 0.6 ## minor tick width in points -#ytick.major.pad : 3.5 ## distance to major tick label in points -#ytick.minor.pad : 3.4 ## distance to the minor tick label in points -ytick.color : 707074 ## color of the tick labels -ytick.labelsize : 12 ## fontsize of the tick labels -#ytick.direction : out ## direction: in, out, or inout -#ytick.minor.visible : False ## visibility of minor ticks on y-axis -#ytick.major.left : True ## draw y axis left major ticks -#ytick.major.right : True ## draw y axis right major ticks -#ytick.minor.left : True ## draw y axis left minor ticks -#ytick.minor.right : True ## draw y axis right minor ticks -#ytick.alignment : center_baseline ## alignment of yticks - -#### GRIDS -grid.color : 93939c ## grid color -grid.linestyle : -- ## solid -#grid.linewidth : 0.8 ## in points -grid.alpha : 0.2 ## transparency, between 0.0 and 1.0 - -#### Legend -#legend.loc : best -#legend.frameon : True ## if True, draw the legend on a background patch -#legend.framealpha : 0.8 ## legend patch transparency -#legend.facecolor : inherit ## inherit from axes.facecolor; or color spec -#legend.edgecolor : 0.8 ## background patch boundary color -#legend.fancybox : True ## if True, use a rounded box for the - ## legend background, else a rectangle -#legend.shadow : False ## if True, give background a shadow effect -#legend.numpoints : 1 ## the number of marker points in the legend line -#legend.scatterpoints : 1 ## number of scatter points -#legend.markerscale : 1.0 ## the relative size of legend markers vs. 
original -#legend.fontsize : medium -#legend.title_fontsize : None ## None sets to the same as the default axes. -## Dimensions as fraction of fontsize: -#legend.borderpad : 0.4 ## border whitespace -#legend.labelspacing : 0.5 ## the vertical space between the legend entries -#legend.handlelength : 2.0 ## the length of the legend lines -#legend.handleheight : 0.7 ## the height of the legend handle -#legend.handletextpad : 0.8 ## the space between the legend line and legend text -#legend.borderaxespad : 0.5 ## the border between the axes and legend edge -#legend.columnspacing : 2.0 ## column separation - -#### FIGURE -## See http://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure -#figure.titlesize : large ## size of the figure title (Figure.suptitle()) -#figure.titleweight : normal ## weight of the figure title -#figure.figsize : 6.4, 4.8 ## figure size in inches -#figure.dpi : 100 ## figure dots per inch -#figure.facecolor : white ## figure facecolor -#figure.edgecolor : white ## figure edgecolor -#figure.frameon : True ## enable figure frame -#figure.max_open_warning : 20 ## The maximum number of figures to open through - ## the pyplot interface before emitting a warning. - ## If less than one this feature is disabled. -## The figure subplot parameters. All dimensions are a fraction of the -#figure.subplot.left : 0.125 ## the left side of the subplots of the figure -#figure.subplot.right : 0.9 ## the right side of the subplots of the figure -#figure.subplot.bottom : 0.11 ## the bottom of the subplots of the figure -#figure.subplot.top : 0.88 ## the top of the subplots of the figure -#figure.subplot.wspace : 0.2 ## the amount of width reserved for space between subplots, - ## expressed as a fraction of the average axis width -#figure.subplot.hspace : 0.2 ## the amount of height reserved for space between subplots, - ## expressed as a fraction of the average axis height - -## Figure layout -#figure.autolayout : False ## When True, automatically adjust subplot - ## parameters to make the plot fit the figure - ## using `tight_layout` -#figure.constrained_layout.use: False ## When True, automatically make plot - ## elements fit on the figure. (Not compatible - ## with `autolayout`, above). -#figure.constrained_layout.h_pad : 0.04167 ## Padding around axes objects. Float representing -#figure.constrained_layout.w_pad : 0.04167 ## inches. Default is 3./72. inches (3 pts) -#figure.constrained_layout.hspace : 0.02 ## Space between subplot groups. Float representing -#figure.constrained_layout.wspace : 0.02 ## a fraction of the subplot widths being separated. - -#### IMAGES -#image.aspect : equal ## equal | auto | a number -#image.interpolation : nearest ## see help(imshow) for options -#image.cmap : viridis ## A colormap name, gray etc... -#image.lut : 256 ## the size of the colormap lookup table -#image.origin : upper ## lower | upper -#image.resample : True -#image.composite_image : True ## When True, all the images on a set of axes are - ## combined into a single composite image before - ## saving a figure as a vector graphics file, - ## such as a PDF. - -#### CONTOUR PLOTS -#contour.negative_linestyle : dashed ## string or on-off ink sequence -#contour.corner_mask : True ## True | False | legacy - -#### ERRORBAR PLOTS -#errorbar.capsize : 0 ## length of end cap on error bars in pixels - -#### HISTOGRAM PLOTS -#hist.bins : 10 ## The default number of histogram bins. 
- ## If Numpy 1.11 or later is - ## installed, may also be `auto` - -#### SCATTER PLOTS -#scatter.marker : o ## The default marker type for scatter plots. - -#### Agg rendering -#### Warning: experimental, 2008/10/10 -#agg.path.chunksize : 0 ## 0 to disable; values in the range - ## 10000 to 100000 can improve speed slightly - ## and prevent an Agg rendering failure - ## when plotting very large data sets, - ## especially if they are very gappy. - ## It may cause minor artifacts, though. - ## A value of 20000 is probably a good - ## starting point. -#### PATHS -#path.simplify : True ## When True, simplify paths by removing "invisible" - ## points to reduce file size and increase rendering - ## speed -#path.simplify_threshold : 0.111111111111 ## The threshold of similarity below which - ## vertices will be removed in the - ## simplification process -#path.snap : True ## When True, rectilinear axis-aligned paths will be snapped to - ## the nearest pixel when certain criteria are met. When False, - ## paths will never be snapped. -#path.sketch : None ## May be none, or a 3-tuple of the form (scale, length, - ## randomness). - ## *scale* is the amplitude of the wiggle - ## perpendicular to the line (in pixels). *length* - ## is the length of the wiggle along the line (in - ## pixels). *randomness* is the factor by which - ## the length is randomly scaled. -#path.effects : [] ## - -#### SAVING FIGURES -## the default savefig params can be different from the display params -## e.g., you may want a higher resolution, or to make the figure -## background white -#savefig.dpi : figure ## figure dots per inch or 'figure' -#savefig.facecolor : white ## figure facecolor when saving -#savefig.edgecolor : white ## figure edgecolor when saving -#savefig.format : png ## png, ps, pdf, svg -#savefig.bbox : standard ## 'tight' or 'standard'. - ## 'tight' is incompatible with pipe-based animation - ## backends but will workd with temporary file based ones: - ## e.g. setting animation.writer to ffmpeg will not work, - ## use ffmpeg_file instead -#savefig.pad_inches : 0.1 ## Padding to be used when bbox is set to 'tight' -#savefig.jpeg_quality: 95 ## when a jpeg is saved, the default quality parameter. -#savefig.directory : ~ ## default directory in savefig dialog box, - ## leave empty to always use current working directory -#savefig.transparent : False ## setting that controls whether figures are saved with a - ## transparent background by default -#savefig.frameon : True ## enable frame of figure when saving -#savefig.orientation : portrait ## Orientation of saved figure - -### tk backend params -#tk.window_focus : False ## Maintain shell focus for TkAgg - -### ps backend params -#ps.papersize : letter ## auto, letter, legal, ledger, A0-A10, B0-B10 -#ps.useafm : False ## use of afm fonts, results in small files -#ps.usedistiller : False ## can be: None, ghostscript or xpdf - ## Experimental: may produce smaller files. 
- ## xpdf intended for production of publication quality files, - ## but requires ghostscript, xpdf and ps2eps -#ps.distiller.res : 6000 ## dpi -#ps.fonttype : 3 ## Output Type 3 (Type3) or Type 42 (TrueType) - -### pdf backend params -#pdf.compression : 6 ## integer from 0 to 9 - ## 0 disables compression (good for debugging) -#pdf.fonttype : 3 ## Output Type 3 (Type3) or Type 42 (TrueType) -#pdf.use14corefonts : False -#pdf.inheritcolor : False - -### svg backend params -#svg.image_inline : True ## write raster image data directly into the svg file -#svg.fonttype : path ## How to handle SVG fonts: - ## none: Assume fonts are installed on the machine where the SVG will be viewed. - ## path: Embed characters as paths -- supported by most SVG renderers - ## svgfont: Embed characters as SVG fonts -- supported only by Chrome, - ## Opera and Safari -#svg.hashsalt : None ## if not None, use this string as hash salt - ## instead of uuid4 -### pgf parameter -#pgf.rcfonts : True -#pgf.preamble : -#pgf.texsystem : xelatex - -### docstring params -##docstring.hardcopy = False ## set this when you want to generate hardcopy docstring - -## Event keys to interact with figures/plots via keyboard. -## Customize these settings according to your needs. -## Leave the field(s) empty if you don't need a key-map. (i.e., fullscreen : '') -#keymap.fullscreen : f, ctrl+f ## toggling -#keymap.home : h, r, home ## home or reset mnemonic -#keymap.back : left, c, backspace ## forward / backward keys to enable -#keymap.forward : right, v ## left handed quick navigation -#keymap.pan : p ## pan mnemonic -#keymap.zoom : o ## zoom mnemonic -#keymap.save : s, ctrl+s ## saving current figure -#keymap.help : f1 ## display help about active tools -#keymap.quit : ctrl+w, cmd+w, q ## close the current figure -#keymap.quit_all : W, cmd+W, Q ## close all figures -#keymap.grid : g ## switching on/off major grids in current axes -#keymap.grid_minor : G ## switching on/off minor grids in current axes -#keymap.yscale : l ## toggle scaling of y-axes ('log'/'linear') -#keymap.xscale : k, L ## toggle scaling of x-axes ('log'/'linear') -#keymap.all_axes : a ## enable all axes -#keymap.copy : ctrl+c, cmd+c ## Copy figure to clipboard - -###ANIMATION settings -#animation.html : none ## How to display the animation as HTML in - ## the IPython notebook. 'html5' uses - ## HTML5 video tag; 'jshtml' creates a - ## Javascript animation -#animation.writer : ffmpeg ## MovieWriter 'backend' to use -#animation.codec : h264 ## Codec to use for writing movie -#animation.bitrate: -1 ## Controls size/quality tradeoff for movie. - ## -1 implies let utility auto-determine -#animation.frame_format: png ## Controls frame format used by temp files -#animation.html_args: ## Additional arguments to pass to html writer -#animation.ffmpeg_path: ffmpeg ## Path to ffmpeg binary. Without full path - ## $PATH is searched -#animation.ffmpeg_args: ## Additional arguments to pass to ffmpeg -#animation.avconv_path: avconv ## Path to avconv binary. Without full path - ## $PATH is searched -#animation.avconv_args: ## Additional arguments to pass to avconv -#animation.convert_path: convert ## Path to ImageMagick's convert binary. - ## On Windows use the full path since convert - ## is also the name of a system tool. 
-#animation.convert_args:            ## Additional arguments to pass to convert
-#animation.embed_limit : 20.0
\ No newline at end of file
diff --git a/research/helper.sh b/research/helper.sh
deleted file mode 100644
index c0635998..00000000
--- a/research/helper.sh
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Description                                                               #
-# ========================================================================= #
-
-# before sourcing this script, it requires the following variables to be exported:
-# - PROJECT: str
-# - PARTITION: str
-# - PARALLELISM: int
-
-# source this script from the script you use to run the experiment, it:
-# 1. gets and exports the path to the root
-# 2. changes the working directory to the root
-# 3. exports helper functions that run the experiment script in the
-#    background, with the correct python path and settings
-
-if [ -z "$PROJECT" ]; then echo "PROJECT is not set"; exit 1; fi
-if [ -z "$PARTITION" ]; then echo "PARTITION is not set"; exit 1; fi
-if [ -z "$PARALLELISM" ]; then echo "PARALLELISM is not set"; exit 1; fi
-if [ -z "$USERNAME" ]; then echo "USERNAME is not set"; exit 1; fi
-if [ -z "$PY_RUN_FILE" ]; then PY_RUN_FILE='experiment/run.py'; fi
-
-export PY_RUN_FILE
-
-# ========================================================================= #
-# Helper                                                                    #
-# ========================================================================= #
-
-# get the root directory
-SCRIPT_DIR=$(dirname "$(realpath -s "$0")")
-ROOT_DIR="$(realpath -s "$SCRIPT_DIR/../..")"
-
-# cd into the root, exit on failure
-cd "$ROOT_DIR" || exit 1
-echo "working directory is: $(pwd)"
-
-function submit_sweep() {
-  echo "SUBMITTING SWEEP:" "$@"
-  PYTHONPATH="$ROOT_DIR" python3 "$PY_RUN_FILE" -m \
-    run_launcher=slurm \
-    dsettings.launcher.partition="$PARTITION" \
-    settings.job.project="$PROJECT" \
-    settings.job.user="$USERNAME" \
-    hydra.launcher.array_parallelism="$PARALLELISM" \
-    "$@" \
-    & # run in background
-}
-
-function local_run() {
-  echo "RUNNING:" "$@"
-  PYTHONPATH="$ROOT_DIR" python3 "$PY_RUN_FILE" \
-    run_launcher=local \
-    settings.job.project="$PROJECT" \
-    settings.job.user="$USERNAME" \
-    "$@"
-}
-
-function local_sweep() {
-  echo "RUNNING SWEEP:" "$@"
-  PYTHONPATH="$ROOT_DIR" python3 "$PY_RUN_FILE" -m \
-    run_launcher=local \
-    settings.job.project="$PROJECT" \
-    settings.job.user="$USERNAME" \
-    "$@"
-}
-
-# export -- note that functions need `export -f`, plain `export` only marks variables
-export ROOT_DIR
-export -f submit_sweep
-export -f local_run
-export -f local_sweep
-
-# debug hydra
-HYDRA_FULL_ERROR=1
-export HYDRA_FULL_ERROR
-
-# ========================================================================= #
-# Slurm Helper                                                              #
-# ========================================================================= #
-
-
-function num_idle_nodes() {
-  if [ -z "$1" ]; then echo "partition (first arg) is not set"; exit 1; fi
-  # number of idle nodes
-  num=$(sinfo --partition="$1" --noheader -O Nodes,Available,StateCompact | awk '{if($2 == "up" && $3 == "idle"){print $1}}')
-  if [ -z "$num" ]; then num=0; fi
-  echo "$num"
-}
-
-function clog_cudaless_nodes() {
-  if [ -z "$1" ]; then echo "partition is not set"; exit 1; fi
-  # assign the defaults, do not echo them, otherwise `$wait` and `$name` are left unset
-  if [ -z "$2" ]; then wait=120; else wait="$2"; fi
-  if [ -z "$3" ]; then name="NO-CUDA"; else name="$3"; fi
-  # clog idle nodes
-  n=$(num_idle_nodes "$1")
-  if [ "$n" -lt "1" ]; then
-    echo -e "\e[93mclogging skipped! no idle nodes found on partition '$1'\e[0m";
-  else
-    echo -e "\e[92mclogging $n nodes on partition '$1' for ${wait}s if cuda is not available!\e[0m";
-    sbatch --array=1-"$n" --partition="$1" --job-name="$name" --output=/dev/null --error=/dev/null \
-      --wrap='python -c "import torch; import time; cuda=torch.cuda.is_available(); print(\"CUDA:\", cuda, flush=True); print(flush=True); time.sleep(5 if cuda else '"$wait"');"'
-  fi
-}
-
-function clog_cuda_nodes() {
-  if [ -z "$1" ]; then echo "partition is not set"; exit 1; fi
-  # assign the defaults, do not echo them, otherwise `$wait` and `$name` are left unset
-  if [ -z "$2" ]; then wait=120; else wait="$2"; fi
-  if [ -z "$3" ]; then name="HAS-CUDA"; else name="$3"; fi
-  # clog idle nodes
-  n=$(num_idle_nodes "$1")
-  if [ "$n" -lt "1" ]; then
-    echo -e "\e[93mclogging skipped! no idle nodes found on partition '$1'\e[0m";
-  else
-    echo -e "\e[92mclogging $n nodes on partition '$1' for ${wait}s if cuda is available!\e[0m";
-    sbatch --array=1-"$n" --partition="$1" --job-name="$name" --output=/dev/null --error=/dev/null \
-      --wrap='python -c "import torch; import time; cuda=torch.cuda.is_available(); print(\"CUDA:\", cuda, flush=True); print(flush=True); time.sleep(5 if not cuda else '"$wait"');"'
-  fi
-}
-
-export -f num_idle_nodes
-export -f clog_cudaless_nodes
-export -f clog_cuda_nodes
-
-# ========================================================================= #
-# End                                                                       #
-# ========================================================================= #
diff --git a/research/plot_wandb_experiments/plot_experiments.py b/research/plot_wandb_experiments/plot_experiments.py
deleted file mode 100644
index 6233eb38..00000000
--- a/research/plot_wandb_experiments/plot_experiments.py
+++ /dev/null
@@ -1,373 +0,0 @@
-import os
-from typing import List
-from typing import Optional
-
-import pandas as pd
-import seaborn as sns
-import wandb
-from cachier import cachier as _cachier
-from matplotlib import pyplot as plt
-from tqdm import tqdm
-
-import research.util as H
-from disent.util.function import wrapped_partial
-
-
-# ========================================================================= #
-# Helper                                                                    #
-# ========================================================================= #
-
-
-cachier = wrapped_partial(_cachier, cache_dir='./cache')
-DF = pd.DataFrame
-
-
-def clear_cache():
-    load_runs.clear_cache()
-
-
-# ========================================================================= #
-# Load WANDB Data                                                           #
-# ========================================================================= #
-
-
-@cachier()
-def load_runs(project: str) -> pd.DataFrame:
-    api = wandb.Api()
-
-    runs = api.runs(project)
-
-    info_list, summary_list, config_list, name_list = [], [], [], []
-    for run in tqdm(runs, desc=f'loading: {project}'):
-        info_list.append({
-            'id': run.id,
-            'name': run.name,
-            'state': run.state,
-            'storage_id': run.storage_id,
-            'url': run.url,
-        })
-        summary_list.append(run.summary._json_dict)
-        config_list.append({k: v for k, v in run.config.items() if not k.startswith('_')})
-        name_list.append(run.name)
-
-    return pd.DataFrame({
-        "info": info_list,
-        "summary": summary_list,
-        "config": config_list,
-        "name": name_list
-    })
-
-
-def load_expanded_runs(project: str) -> pd.DataFrame:
-    # load the data
-    df_runs: DF = load_runs(project)
-    # expand the dictionaries into one column per key
-    df_info: DF = df_runs['info'].apply(pd.Series)
-    df_summary: DF = df_runs['summary'].apply(pd.Series)
-    df_config: DF = df_runs['config'].apply(pd.Series)
-    # merge the data
-    df: DF = df_config.join(df_summary).join(df_info)
-    assert len(df.columns) == len(df_info.columns) + len(df_summary.columns) + len(df_config.columns)
-    # done!
-    return df
-
-
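The expansion above relies on a small pandas idiom: applying `pd.Series` to a column of dicts turns each key into its own column. A minimal self-contained sketch of just that idiom (the column name and values here are illustrative only, not taken from the wandb data):

    import pandas as pd

    # a column of per-run dicts, as produced by the loader above
    df = pd.DataFrame({'config': [{'lr': 0.001, 'beta': 4.0}, {'lr': 0.0001, 'beta': 1.0}]})
    # each dict becomes a row of a new frame, with one column per key
    expanded = df['config'].apply(pd.Series)
    print(list(expanded.columns))     # ['lr', 'beta']
    print(expanded['beta'].tolist())  # [4.0, 1.0]
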
-def drop_unhashable(df: pd.DataFrame, inplace: bool = False) -> (pd.DataFrame, List[str]):
-    dropped = []
-    for col in df.columns:
-        try:
-            df[col].unique()
-        except TypeError:
-            # unhashable entries (eg. dicts or lists) cannot be uniqued
-            dropped.append(col)
-            # NB: `DataFrame.drop(..., inplace=True)` returns None, so it must not be assigned back
-            if inplace:
-                df.drop(col, inplace=True, axis=1)
-            else:
-                df = df.drop(col, axis=1)
-    return df, dropped
-
-
-def drop_non_diverse_cols(df: pd.DataFrame, inplace: bool = False) -> (pd.DataFrame, List[str]):
-    dropped = []
-    for col in df.columns:
-        if len(df[col].unique()) == 1:
-            dropped.append(col)
-            # NB: `DataFrame.drop(..., inplace=True)` returns None, so it must not be assigned back
-            if inplace:
-                df.drop(col, inplace=True, axis=1)
-            else:
-                df = df.drop(col, axis=1)
-    return df, dropped
-
-
-# ========================================================================= #
-# Prepare Data                                                              #
-# ========================================================================= #
-
-
-# common keys
-K_GROUP = 'Run Group'
-K_DATASET = 'Dataset'
-K_FRAMEWORK = 'Framework'
-K_SPACING = 'Grid Spacing'
-K_BETA = 'Beta'
-K_LOSS = 'Recon. Loss'
-K_Z_SIZE = 'Latent Dims.'
-K_REPEAT = 'Repeat'
-K_STATE = 'State'
-K_MIG = 'MIG Score'
-K_DCI = 'DCI Score'
-
-
-def load_general_data(project: str):
-    # load data
-    df = load_expanded_runs(project)
-    # filter out unneeded columns
-    df, dropped_hash = drop_unhashable(df)
-    df, dropped_diverse = drop_non_diverse_cols(df)
-    # rename columns
-    return df.rename(columns={
-        'EXTRA/tags': K_GROUP,
-        'dataset/name': K_DATASET,
-        'framework/name': K_FRAMEWORK,
-        'dataset/data/grid_spacing': K_SPACING,
-        'settings/framework/beta': K_BETA,
-        'settings/framework/recon_loss': K_LOSS,
-        'settings/model/z_size': K_Z_SIZE,
-        'DUMMY/repeat': K_REPEAT,
-        'state': K_STATE,
-        'final_metric/mig.discrete_score.max': K_MIG,
-        'final_metric/dci.disentanglement.max': K_DCI,
-    })
-
-
-# ========================================================================= #
-# Plot Experiments                                                          #
-# ========================================================================= #
-
-PINK = '#FE375F'
-PURPLE = '#5E5BE5'
-BLUE = '#0A83FE'
-LBLUE = '#63D2FE'
-ORANGE = '#FE9F0A'
-GREEN = '#2FD157'
-
-
-def plot_incr_overlap_exp(
-    rel_path: Optional[str] = None,
-    save: bool = True,
-    show: bool = True,
-    reg_order: int = 4,
-    color_betavae: str = PINK,
-    color_adavae: str = ORANGE,
-    titles: bool = False,
-):
-    # ~=~=~=~=~=~=~=~=~=~=~=~=~ #
-    df = load_general_data(f'{os.environ["WANDB_USER"]}/CVPR-01__incr_overlap')
-    # select run groups
-    df = df[df[K_GROUP].isin(['sweep_xy_squares_overlap', 'sweep_xy_squares_overlap_small_beta'])]
-    # print common key values
-    print('K_GROUP:    ', list(df[K_GROUP].unique()))
-    print('K_FRAMEWORK:', list(df[K_FRAMEWORK].unique()))
-    print('K_SPACING:  ', list(df[K_SPACING].unique()))
-    print('K_BETA:     ', list(df[K_BETA].unique()))
-    print('K_REPEAT:   ', list(df[K_REPEAT].unique()))
-    print('K_STATE:    ', list(df[K_STATE].unique()))
-    # ~=~=~=~=~=~=~=~=~=~=~=~=~ #
-
-    BETA = 0.00316  # if grid_spacing < 6
-    BETA = 0.001    # if grid_spacing >= 6
-
-    # ~=~=~=~=~=~=~=~=~=~=~=~=~ #
-    orig = df
-    # select runs
-    # df = df[df[K_STATE] == 'finished']
-    # df = df[df[K_REPEAT].isin([1, 2, 3])]
-    # select adavae
-    adavae_selector = (df[K_FRAMEWORK] == 'adavae_os') & (df[K_BETA] == 0.001)  # 0.001, 0.0001
-    data_adavae = df[adavae_selector]
-    # select
-    betavae_selector_a = (df[K_FRAMEWORK] == 'betavae') & (df[K_BETA] == 0.001) & (df[K_SPACING] >= 3)
-    betavae_selector_b = (df[K_FRAMEWORK] == 'betavae') & (df[K_BETA] == 0.00316) & (df[K_SPACING] < 3)
-    data_betavae = df[betavae_selector_a | betavae_selector_b]
-    # ~=~=~=~=~=~=~=~=~=~=~=~=~ #
-
-    print('ADAGVAE', len(orig), '->',
len(data_adavae)) - print('BETAVAE', len(orig), '->', len(data_betavae)) - - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - fig, axs = plt.subplots(1, 2, figsize=(10, 5)) - (ax0, ax1) = axs - # PLOT: MIG - sns.regplot(ax=ax0, x=K_SPACING, y=K_MIG, data=data_adavae, seed=777, order=reg_order, robust=False, color=color_adavae, marker='o') - sns.regplot(ax=ax0, x=K_SPACING, y=K_MIG, data=data_betavae, seed=777, order=reg_order, robust=False, color=color_betavae, marker='x', line_kws=dict(linestyle='dashed')) - ax0.legend(labels=["Ada-GVAE", "Beta-VAE"], fontsize=14) - ax0.set_ylim([-0.1, 1.1]) - ax0.set_xlim([0.8, 8.2]) - if titles: ax0.set_title('Framework Mig Scores') - # PLOT: DCI - sns.regplot(ax=ax1, x=K_SPACING, y=K_DCI, data=data_adavae, seed=777, order=reg_order, robust=False, color=color_adavae, marker='o') - sns.regplot(ax=ax1, x=K_SPACING, y=K_DCI, data=data_betavae, seed=777, order=reg_order, robust=False, color=color_betavae, marker='x', line_kws=dict(linestyle='dashed')) - ax1.legend(labels=["Ada-GVAE", "Beta-VAE"], fontsize=14) - ax1.set_ylim([-0.1, 1.1]) - ax1.set_xlim([0.8, 8.2]) - if titles: ax1.set_title('Framework DCI Scores') - # PLOT: - fig.tight_layout() - H.plt_rel_path_savefig(rel_path, save=save, show=show) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - return fig, axs - - - - - -def plot_hparams_exp( - rel_path: Optional[str] = None, - save: bool = True, - show: bool = True, - color_betavae: str = PINK, - color_adavae: str = ORANGE, -): - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - df = load_general_data(f'{os.environ["WANDB_USER"]}/CVPR-00__basic-hparam-tuning') - # select run groups - df = df[df[K_GROUP].isin(['sweep_beta'])] - # print common key values - print('K_GROUP: ', list(df[K_GROUP].unique())) - print('K_DATASET: ', list(df[K_DATASET].unique())) - print('K_FRAMEWORK:', list(df[K_FRAMEWORK].unique())) - print('K_BETA: ', list(df[K_BETA].unique())) - print('K_Z_SIZE: ', list(df[K_Z_SIZE].unique())) - print('K_REPEAT: ', list(df[K_REPEAT].unique())) - print('K_STATE: ', list(df[K_STATE].unique())) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - orig = df - # select runs - df = df[df[K_STATE] == 'finished'] - # [1.0, 0.316, 0.1, 0.0316, 0.01, 0.00316, 0.001, 0.000316] - # df = df[(0.000316 < df[K_BETA]) & (df[K_BETA] < 1.0)] - print('NUM', len(orig), '->', len(df)) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - df = df[[K_DATASET, K_FRAMEWORK, K_MIG, K_DCI]] - df[K_DATASET].replace('xysquares_minimal', 'XYSquares', inplace=True) - df[K_DATASET].replace('smallnorb', 'NORB', inplace=True) - df[K_DATASET].replace('cars3d', 'Cars3D', inplace=True) - df[K_DATASET].replace('3dshapes', 'Shapes3D', inplace=True) - df[K_DATASET].replace('dsprites', 'dSprites', inplace=True) - df[K_FRAMEWORK].replace('adavae_os', 'Ada-GVAE', inplace=True) - df[K_FRAMEWORK].replace('betavae', 'Beta-VAE', inplace=True) - PALLETTE = {'Ada-GVAE': color_adavae, 'Beta-VAE': color_betavae} - - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - fig, axs = plt.subplots(1, 2, figsize=(10, 4)) - (ax0, ax1) = axs - # PLOT: MIG - sns.violinplot(x=K_DATASET, y=K_MIG, hue=K_FRAMEWORK, palette=PALLETTE, split=True, cut=0, width=0.75, data=df, ax=ax0, scale='width', inner='quartile') - ax0.set_ylim([-0.1, 1.1]) - ax0.legend(bbox_to_anchor=(0.425, 0.9), fontsize=13) - sns.violinplot(x=K_DATASET, y=K_DCI, hue=K_FRAMEWORK, palette=PALLETTE, split=True, cut=0, width=0.75, data=df, ax=ax1, scale='width', inner='quartile') - ax1.set_ylim([-0.1, 1.1]) - ax1.get_legend().remove() - # PLOT: - fig.tight_layout() - H.plt_rel_path_savefig(rel_path, save=save, 
show=show, dpi=300) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - return fig, axs - - - -def plot_overlap_loss_exp( - rel_path: Optional[str] = None, - save: bool = True, - show: bool = True, - color_betavae: str = PINK, - color_adavae: str = ORANGE, - color_mse: str = '#9FD911', - color_mse_overlap: str = '#36CFC8', -): - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - df = load_general_data(f'{os.environ["WANDB_USER"]}/CVPR-09__vae_overlap_loss') - # select run groups - df = df[df[K_GROUP].isin(['sweep_overlap_boxblur_specific', 'sweep_overlap_boxblur'])] - # print common key values - print('K_GROUP: ', list(df[K_GROUP].unique())) - print() - print('K_DATASET: ', list(df[K_DATASET].unique())) - print('K_FRAMEWORK:', list(df[K_FRAMEWORK].unique())) - print('K_Z_SIZE: ', list(df[K_Z_SIZE].unique())) - print('K_LOSS: ', list(df[K_LOSS].unique())) - print('K_BETA: ', list(df[K_BETA].unique())) - print() - print('K_REPEAT: ', list(df[K_REPEAT].unique())) - print('K_STATE: ', list(df[K_STATE].unique())) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - # # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - orig = df - # select runs - df = df[df[K_STATE] == 'finished'] # TODO: update - df = df[df[K_DATASET] == 'xysquares_minimal'] - df = df[df[K_BETA].isin([0.0001, 0.0316])] - df = df[df[K_Z_SIZE] == 25] - # df = df[df[K_FRAMEWORK] == 'betavae'] # 18 - # df = df[df[K_FRAMEWORK] == 'adavae_os'] # 21 - # df = df[df[K_LOSS] == 'mse'] # 20 - # df = df[df[K_LOSS] != 'mse'] # 19 - # # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - # TEMP - # df[K_MIG] = df['final_metric/mig.discrete_score.max'] - # df[K_DCI] = df['final_metric/dci.disentanglement.max'] - - print('NUM', len(orig), '->', len(df)) - - df = df[[K_DATASET, K_FRAMEWORK, K_LOSS, K_BETA, K_MIG, K_DCI]] - df[K_DATASET].replace('xysquares_minimal', 'XYSquares', inplace=True) - df[K_FRAMEWORK].replace('adavae_os', 'Ada-GVAE', inplace=True) - df[K_FRAMEWORK].replace('betavae', 'Beta-VAE', inplace=True) - df[K_LOSS].replace('mse_box_r31_l1.0_k3969.0', 'MSE-boxblur', inplace=True) - df[K_LOSS].replace('mse', 'MSE', inplace=True) - PALLETTE = {'Ada-GVAE': color_adavae, 'Beta-VAE': color_betavae, 'MSE': color_mse, 'MSE-boxblur': color_mse_overlap} - - print(df) - - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - fig, axs = plt.subplots(1, 2, figsize=(10, 4)) - (ax0, ax1) = axs - # PLOT: MIG - sns.violinplot(x=K_FRAMEWORK, y=K_MIG, hue=K_LOSS, palette=PALLETTE, split=True, cut=0, width=0.5, data=df, ax=ax0, scale='width', inner='quartile') - ax0.set_ylim([-0.1, 1.1]) - ax0.legend(fontsize=13) - # ax0.legend(bbox_to_anchor=(0.425, 0.9), fontsize=13) - sns.violinplot(x=K_FRAMEWORK, y=K_DCI, hue=K_LOSS, palette=PALLETTE, split=True, cut=0, width=0.5, data=df, ax=ax1, scale='width', inner='quartile') - ax1.set_ylim([-0.1, 1.1]) - ax1.get_legend().remove() - # PLOT: - fig.tight_layout() - H.plt_rel_path_savefig(rel_path, save=save, show=show, dpi=300) - # ~=~=~=~=~=~=~=~=~=~=~=~=~ # - - -# ========================================================================= # -# Entrypoint # -# ========================================================================= # - - -if __name__ == '__main__': - - assert 'WANDB_USER' in os.environ, 'specify "WANDB_USER" environment variable' - - # matplotlib style - plt.style.use(os.path.join(os.path.dirname(__file__), '../gadfly.mplstyle')) - - # clear_cache() - - def main(): - # plot_hparams_exp(rel_path='plots/exp_hparams-exp', show=True) - plot_overlap_loss_exp(rel_path='plots/exp_overlap-loss', show=True) - # plot_incr_overlap_exp(rel_path='plots/exp_incr-overlap', show=True) - - main() - - -# 
========================================================================= # -# DONE # -# ========================================================================= # diff --git a/research/plot_wandb_experiments/plots/.gitignore b/research/plot_wandb_experiments/plots/.gitignore deleted file mode 100644 index e33609d2..00000000 --- a/research/plot_wandb_experiments/plots/.gitignore +++ /dev/null @@ -1 +0,0 @@ -*.png diff --git a/research/util/__init__.py b/research/util/__init__.py deleted file mode 100644 index e02c2205..00000000 --- a/research/util/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ - -from ._fn_util import * -from ._dataset import * -from ._io_util import * -from ._loss import * - -# disent exports to make life easy -from disent.util.visualize.plot import to_img -from disent.util.visualize.plot import to_imgs -from disent.util.visualize.plot import plt_imshow -from disent.util.visualize.plot import plt_subplots -from disent.util.visualize.plot import plt_subplots_imshow -from disent.util.visualize.plot import plt_hide_axis -from disent.util.visualize.plot import visualize_dataset_traversal -from disent.util.visualize.plot import plt_2d_density diff --git a/research/util/_data.py b/research/util/_data.py deleted file mode 100644 index cce196e9..00000000 --- a/research/util/_data.py +++ /dev/null @@ -1,82 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -from typing import Tuple - -import numpy as np - -from disent.dataset.data import GroundTruthData -from disent.dataset.data._raw import Hdf5Dataset - - -# TODO: these classes are old... -# TODO: these classes are old... -# TODO: these classes are old... 
-
-
-class TransformDataset(GroundTruthData):
-
-    # TODO: all data should be datasets
-    # TODO: file preparation should be separate from datasets
-    # TODO: disent/data should be datasets, and disent/datasets should be samplers that wrap disent/data
-
-    def __init__(self, base_data: GroundTruthData, transform=None):
-        self.base_data = base_data
-        super().__init__(transform=transform)
-
-    @property
-    def factor_names(self) -> Tuple[str, ...]:
-        return self.base_data.factor_names
-
-    @property
-    def factor_sizes(self) -> Tuple[int, ...]:
-        return self.base_data.factor_sizes
-
-    @property
-    def img_shape(self) -> Tuple[int, ...]:
-        return self.base_data.img_shape
-
-    def _get_observation(self, idx):
-        return self.base_data[idx]
-
-
-class AdversarialOptimizedData(TransformDataset):
-
-    def __init__(self, h5_path: str, base_data: GroundTruthData, transform=None):
-        # normalize hdf5 data
-        def _normalize_hdf5(x):
-            c, h, w = x.shape
-            if c in (1, 3):
-                return np.moveaxis(x, 0, -1)
-            return x
-        # get the data
-        self.hdf5_data = Hdf5Dataset(h5_path, transform=_normalize_hdf5)
-        # checks
-        assert isinstance(base_data, GroundTruthData), f'base_data must be an instance of {repr(GroundTruthData.__name__)}, got: {repr(base_data)}'
-        assert len(base_data) == len(self.hdf5_data), f'length of base_data: {len(base_data)} does not match length of hdf5 data: {len(self.hdf5_data)}'
-        # initialize
-        super().__init__(base_data=base_data, transform=transform)
-
-    def _get_observation(self, idx):
-        return self.hdf5_data[idx]
diff --git a/research/util/_dataset.py b/research/util/_dataset.py
deleted file mode 100644
index 8a30e174..00000000
--- a/research/util/_dataset.py
+++ /dev/null
@@ -1,457 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import os -import warnings -from typing import List -from typing import Literal -from typing import Optional -from typing import Sequence -from typing import Sized -from typing import Tuple -from typing import Union - -import numpy as np -import torch -import torch.utils.data - -from disent.dataset import DisentDataset -from disent.dataset.data import Cars3dData -from disent.dataset.data import DSpritesData -from disent.dataset.data import DSpritesImagenetData -from disent.dataset.data import GroundTruthData -from disent.dataset.data import Shapes3dData -from disent.dataset.data import SmallNorbData -from disent.dataset.data import XColumnsData -from disent.dataset.data import XYBlocksData -from disent.dataset.data import XYObjectData -from disent.dataset.data import XYSquaresData -from disent.dataset.sampling import BaseDisentSampler -from disent.dataset.sampling import GroundTruthSingleSampler -from disent.dataset.transform import Noop -from disent.dataset.transform import ToImgTensorF32 -from disent.dataset.transform import ToImgTensorU8 - - -# ========================================================================= # -# dataset io # -# ========================================================================= # - - -# TODO: this is much faster! -# -# import psutil -# import multiprocessing as mp -# -# def copy_batch_into(src: GroundTruthData, dst: torch.Tensor, i: int, j: int): -# for k in range(i, min(j, len(dst))): -# dst[k, ...] = src[k] -# return (i, j) -# -# def load_dataset_into_memory( -# gt_data: GroundTruthData, -# workers: int = min(psutil.cpu_count(logical=False), 16), -# ) -> ArrayGroundTruthData: -# # make data and tensors -# tensor = torch.zeros(len(gt_data), *gt_data.obs_shape, dtype=gt_data[0].dtype).share_memory_() -# # compute batch size -# n = len(gt_data) -# batch_size = (n + workers - 1) // workers -# # load in batches -# with mp.Pool(processes=workers) as POOL: -# POOL.starmap( -# copy_batch_into, [ -# (gt_data, tensor, i, i + batch_size) -# for i in range(0, n, batch_size) -# ] -# ) -# # return array -# return ArrayGroundTruthData.new_like(tensor, gt_data, array_chn_is_last=False) - - -def load_dataset_into_memory(gt_data: GroundTruthData, x_shape: Optional[Tuple[int, ...]] = None, batch_size=64, num_workers=min(os.cpu_count(), 16), dtype=torch.float32, raw_array=False): - assert dtype in {torch.float16, torch.float32} - # TODO: this should be part of disent? - from torch.utils.data import DataLoader - from tqdm import tqdm - from disent.dataset.data import ArrayGroundTruthData - # get observation shape - # - manually specify this if the gt_data has a transform applied that resizes the observations for example! - if x_shape is None: - x_shape = gt_data.x_shape - # load dataset into memory manually! - data = torch.zeros(len(gt_data), *x_shape, dtype=dtype) - # load all batches - dataloader = DataLoader(gt_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False) - idx = 0 - for batch in tqdm(dataloader, desc='loading dataset into memory'): - data[idx:idx+len(batch)] = batch.to(dtype) - idx += len(batch) - # done! 
- if raw_array: - return data - else: - # channels get swapped by the below ToImgTensorF32(), maybe allow `array_chn_is_last` as param - return ArrayGroundTruthData.new_like(array=data, gt_data=gt_data, array_chn_is_last=False) - - -# ========================================================================= # -# dataset # -# ========================================================================= # - - -TransformTypeHint = Union[Literal['uint8'], Literal['float'], Literal['float32'], Literal['none']] - - -def make_data( - name: str = 'xysquares', - factors: bool = False, - data_root: str = 'data/dataset', - try_in_memory: bool = False, - load_into_memory: bool = False, - load_memory_dtype: torch.dtype = torch.float16, - transform_mode: TransformTypeHint = 'float32' -) -> GroundTruthData: - # override values - if load_into_memory and try_in_memory: - warnings.warn('`load_into_memory==True` is incompatible with `try_in_memory==True`, setting `try_in_memory=False`!') - try_in_memory = False - # transform object - TransformCls = { - 'uint8': ToImgTensorU8, - 'float32': ToImgTensorF32, - 'none': Noop, - }[transform_mode] - # make data - if name == 'xysquares': data = XYSquaresData(transform=TransformCls()) # equivalent: [xysquares, xysquares_8x8, xysquares_8x8_s8] - elif name == 'xysquares_1x1': data = XYSquaresData(square_size=1, transform=TransformCls()) - elif name == 'xysquares_2x2': data = XYSquaresData(square_size=2, transform=TransformCls()) - elif name == 'xysquares_4x4': data = XYSquaresData(square_size=4, transform=TransformCls()) - elif name == 'xysquares_8x8': data = XYSquaresData(square_size=8, transform=TransformCls()) # 8x8x8x8x8x8 = 262144 # equivalent: [xysquares, xysquares_8x8, xysquares_8x8_s8] - elif name == 'xysquares_8x8_mini': data = XYSquaresData(square_size=8, grid_spacing=14, transform=TransformCls()) # 5x5x5x5x5x5 = 15625 - # TOY DATASETS - elif name == 'xysquares_8x8_toy': data = XYSquaresData(square_size=8, grid_spacing=8, rgb=False, num_squares=1, transform=TransformCls()) # 8x8 = ? - elif name == 'xysquares_8x8_toy_s1': data = XYSquaresData(square_size=8, grid_spacing=1, rgb=False, num_squares=1, transform=TransformCls()) # ?x? = ? - elif name == 'xysquares_8x8_toy_s2': data = XYSquaresData(square_size=8, grid_spacing=2, rgb=False, num_squares=1, transform=TransformCls()) # ?x? = ? - elif name == 'xysquares_8x8_toy_s4': data = XYSquaresData(square_size=8, grid_spacing=4, rgb=False, num_squares=1, transform=TransformCls()) # ?x? = ? - elif name == 'xysquares_8x8_toy_s8': data = XYSquaresData(square_size=8, grid_spacing=8, rgb=False, num_squares=1, transform=TransformCls()) # 8x8 = ? - # TOY DATASETS ALT - elif name == 'xcolumns_8x_toy': data = XColumnsData(square_size=8, grid_spacing=8, rgb=False, num_squares=1, transform=TransformCls()) # 8 = ? - elif name == 'xcolumns_8x_toy_s1': data = XColumnsData(square_size=8, grid_spacing=1, rgb=False, num_squares=1, transform=TransformCls()) # ? = ? - elif name == 'xcolumns_8x_toy_s2': data = XColumnsData(square_size=8, grid_spacing=2, rgb=False, num_squares=1, transform=TransformCls()) # ? = ? - elif name == 'xcolumns_8x_toy_s4': data = XColumnsData(square_size=8, grid_spacing=4, rgb=False, num_squares=1, transform=TransformCls()) # ? = ? - elif name == 'xcolumns_8x_toy_s8': data = XColumnsData(square_size=8, grid_spacing=8, rgb=False, num_squares=1, transform=TransformCls()) # 8 = ? 
- # OVERLAPPING DATASETS - elif name == 'xysquares_8x8_s1': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=1, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s2': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=2, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s3': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=3, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s4': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=4, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s5': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=5, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s6': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=6, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s7': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=7, transform=TransformCls()) # ?x?x?x?x?x? = ? - elif name == 'xysquares_8x8_s8': data = XYSquaresData(square_size=8, grid_size=8, grid_spacing=8, transform=TransformCls()) # 8x8x8x8x8x8 = 262144 # equivalent: [xysquares, xysquares_8x8, xysquares_8x8_s8] - # OTHER SYNTHETIC DATASETS - elif name == 'xyobject': data = XYObjectData(transform=TransformCls()) - elif name == 'xyblocks': data = XYBlocksData(transform=TransformCls()) - # NORMAL DATASETS - elif name == 'cars3d': data = Cars3dData(data_root=data_root, prepare=True, transform=TransformCls(size=64)) - elif name == 'smallnorb': data = SmallNorbData(data_root=data_root, prepare=True, transform=TransformCls(size=64)) - elif name == 'shapes3d': data = Shapes3dData(data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites': data = DSpritesData(data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - # CUSTOM DATASETS - elif name == 'dsprites_imagenet_bg_100': data = DSpritesImagenetData(visibility=100, mode='bg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_bg_80': data = DSpritesImagenetData(visibility=80, mode='bg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_bg_60': data = DSpritesImagenetData(visibility=60, mode='bg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_bg_40': data = DSpritesImagenetData(visibility=40, mode='bg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_bg_20': data = DSpritesImagenetData(visibility=20, mode='bg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - # --- # - elif name == 'dsprites_imagenet_fg_100': data = DSpritesImagenetData(visibility=100, mode='fg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_fg_80': data = DSpritesImagenetData(visibility=80, mode='fg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_fg_60': data = DSpritesImagenetData(visibility=60, mode='fg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory) - elif name == 'dsprites_imagenet_fg_40': data = DSpritesImagenetData(visibility=40, mode='fg', data_root=data_root, prepare=True, 
transform=TransformCls(), in_memory=try_in_memory)
-    elif name == 'dsprites_imagenet_fg_20':  data = DSpritesImagenetData(visibility=20,  mode='fg', data_root=data_root, prepare=True, transform=TransformCls(), in_memory=try_in_memory)
-    # DONE
-    else: raise KeyError(f'invalid data name: {repr(name)}')
-    # load into memory
-    if load_into_memory:
-        old_data, data = data, load_dataset_into_memory(data, dtype=load_memory_dtype, x_shape=(data.img_channels, 64, 64))
-    # make dataset
-    if factors:
-        raise NotImplementedError('factor returning is not yet implemented in the rewrite! this needs to be fixed!')  # TODO!
-    return data
-
-
-def make_dataset(
-    name: str = 'xysquares',
-    factors: bool = False,
-    data_root: str = 'data/dataset',
-    try_in_memory: bool = False,
-    load_into_memory: bool = False,
-    load_memory_dtype: torch.dtype = torch.float16,
-    transform_mode: TransformTypeHint = 'float32',
-    sampler: BaseDisentSampler = None,
-) -> DisentDataset:
-    data = make_data(
-        name=name,
-        factors=factors,
-        data_root=data_root,
-        try_in_memory=try_in_memory,
-        load_into_memory=load_into_memory,
-        load_memory_dtype=load_memory_dtype,
-        transform_mode=transform_mode,
-    )
-    return DisentDataset(
-        data,
-        sampler=GroundTruthSingleSampler() if (sampler is None) else sampler,
-        return_indices=True
-    )
-
-
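For orientation, this is roughly how the factory above is meant to be used. A minimal sketch only, assuming the disent APIs imported at the top of this file; the dataset name and batch size are arbitrary, and the printed shape depends on the chosen dataset:

    from torch.utils.data import DataLoader

    dataset = make_dataset('xyobject', transform_mode='float32')
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    (x_targ,) = batch['x_targ']  # DisentDataset batches are dicts of observation lists
    print(x_targ.shape)          # eg. torch.Size([16, 3, 64, 64])
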
-def get_single_batch(dataloader, cuda=True):
-    for batch in dataloader:
-        (x_targ,) = batch['x_targ']
-        break
-    if cuda:
-        x_targ = x_targ.cuda()
-    return x_targ
-
-
-# ========================================================================= #
-# sampling helper                                                           #
-# ========================================================================= #
-
-
-# TODO: clean this up
-def sample_factors(gt_data: GroundTruthData, num_obs: int = 1024, factor_mode: str = 'sample_random', factor: Union[int, str] = None):
-    # sample multiple random factor traversals
-    if factor_mode == 'sample_traversals':
-        assert factor is not None, f'factor cannot be None when factor_mode=={repr(factor_mode)}'
-        # get traversal
-        f_idx = gt_data.normalise_factor_idx(factor)
-        # generate traversals
-        factors = []
-        for i in range((num_obs + gt_data.factor_sizes[f_idx] - 1) // gt_data.factor_sizes[f_idx]):
-            factors.append(gt_data.sample_random_factor_traversal(f_idx=f_idx))
-        factors = np.concatenate(factors, axis=0)
-    elif factor_mode == 'sample_random':
-        factors = gt_data.sample_factors(num_obs)
-    else:
-        raise KeyError(f'invalid factor_mode: {repr(factor_mode)}')
-    return factors
-
-
-# TODO: move into dataset class
-def sample_batch_and_factors(dataset: DisentDataset, num_samples: int, factor_mode: str = 'sample_random', factor: Union[int, str] = None, device=None):
-    factors = sample_factors(dataset.gt_data, num_obs=num_samples, factor_mode=factor_mode, factor=factor)
-    batch = dataset.dataset_batch_from_factors(factors, mode='target').to(device=device)
-    factors = torch.from_numpy(factors).to(dtype=torch.float32, device=device)
-    return batch, factors
-
-
-# ========================================================================= #
-# pair samplers                                                             #
-# ========================================================================= #
-
-
-def pair_indices_random(max_idx: int, approx_batch_size: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Generates pairs of indices in corresponding arrays,
-    returning random permutations
-    - considers [0, 1] and [1, 0] to be different  # TODO: consider them to be the same
-    - never returns pairs with the same values, eg. [1, 1]
-    - (default) number of returned values is: `max_idx * sqrt(max_idx) / 2`  -- arbitrarily chosen to scale slower than number of combinations
-    """
-    # defaults
-    if approx_batch_size is None:
-        approx_batch_size = int(max_idx * (max_idx ** 0.5) / 2)
-    # sample values
-    idx_a, idx_b = np.random.randint(0, max_idx, size=(2, approx_batch_size))
-    # remove similar
-    different = (idx_a != idx_b)
-    idx_a = idx_a[different]
-    idx_b = idx_b[different]
-    # return values
-    return idx_a, idx_b
-
-
-def pair_indices_combinations(max_idx: int) -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Generates pairs of indices in corresponding arrays,
-    returning all combinations
-    - considers [0, 1] and [1, 0] to be the same, only returns one of them
-    - never returns pairs with the same values, eg. [1, 1]
-    - number of returned values is: `max_idx * (max_idx-1) / 2`
-    """
-    # upper triangle excluding diagonal
-    # - similar to: `list(itertools.combinations(np.arange(len(t_idxs)), 2))`
-    idxs_a, idxs_b = np.triu_indices(max_idx, k=1)
-    return idxs_a, idxs_b
-
-
-def pair_indices_nearby(max_idx: int) -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Generates pairs of indices in corresponding arrays,
-    returning nearby combinations
-    - considers [0, 1] and [1, 0] to be the same, only returns one of them
-    - never returns pairs with the same values, eg. [1, 1]
-    - number of returned values is: `max_idx`
-    """
-    idxs_a = np.arange(max_idx)                # eg. [0 1 2 3 4 5]
-    idxs_b = np.roll(idxs_a, shift=1, axis=0)  # eg. [5 0 1 2 3 4]
-    return idxs_a, idxs_b
-
-
-_PAIR_INDICES_FNS = {
-    'random': pair_indices_random,
-    'combinations': pair_indices_combinations,
-    'nearby': pair_indices_nearby,
-}
-
-
-def pair_indices(max_idx: int, mode: str) -> Tuple[np.ndarray, np.ndarray]:
-    try:
-        fn = _PAIR_INDICES_FNS[mode]
-    except KeyError:
-        raise KeyError(f'invalid mode: {repr(mode)}')
-    return fn(max_idx=max_idx)
-
-
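The two deterministic pair samplers above boil down to standard numpy idioms. A tiny standalone check of the pairs they produce (pure numpy; `max_idx=4` is chosen arbitrarily):

    import numpy as np

    # combinations: upper triangle without the diagonal, so i < j for every pair
    idxs_a, idxs_b = np.triu_indices(4, k=1)
    print(list(zip(idxs_a.tolist(), idxs_b.tolist())))
    # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

    # nearby: each index paired with its cyclically shifted neighbour
    idxs = np.arange(4)
    print(list(zip(idxs.tolist(), np.roll(idxs, 1).tolist())))
    # [(0, 3), (1, 0), (2, 1), (3, 2)]
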
-# ========================================================================= #
-# mask helper                                                               #
-# ========================================================================= #
-
-
-def make_changed_mask(batch: torch.Tensor, masked=True):
-    if masked:
-        mask = torch.zeros_like(batch[0], dtype=torch.bool)
-        for i in range(len(batch)):
-            mask |= (batch[0] != batch[i])
-    else:
-        mask = torch.ones_like(batch[0], dtype=torch.bool)
-    return mask
-
-
-# ========================================================================= #
-# dataset indices                                                           #
-# ========================================================================= #
-
-
-def sample_unique_batch_indices(num_obs: int, num_samples: int) -> np.ndarray:
-    assert num_obs >= num_samples, 'not enough values to sample'
-    assert (num_obs - num_samples) / num_obs > 0.5, 'this method might be inefficient'
-    # get random sample
-    indices = set()
-    while len(indices) < num_samples:
-        indices.update(np.random.randint(low=0, high=num_obs, size=num_samples - len(indices)))
-    # make sure indices are randomly ordered
-    indices = np.fromiter(indices, dtype=int)
-    # indices = np.array(list(indices), dtype=int)
-    np.random.shuffle(indices)
-    # return values
-    return indices
-
-
-def generate_epoch_batch_idxs(num_obs: int, num_batches: int, mode: str = 'shuffle') -> List[np.ndarray]:
-    """
-    Generate `num_batches` batches of indices.
-    - Each index is in the range [0, num_obs).
-    - If num_obs is not divisible by num_batches, then batches may not all be the same size.
-
-    eg. [0, 1, 2, 3, 4] -> [[0, 1], [2, 3], [4]] -- num_obs=5, num_batches=3, mode='range'
-    eg. [0, 1, 2, 3, 4] -> [[1, 4], [2, 0], [3]] -- num_obs=5, num_batches=3, mode='shuffle'
-    eg. [0, 1, 0, 3, 2] -> [[0, 1], [0, 3], [2]] -- num_obs=5, num_batches=3, mode='random'
-    """
-    # generate indices
-    if mode == 'range':
-        idxs = np.arange(num_obs)
-    elif mode == 'shuffle':
-        idxs = np.arange(num_obs)
-        np.random.shuffle(idxs)
-    elif mode == 'random':
-        idxs = np.random.randint(0, num_obs, size=(num_obs,))
-    else:
-        raise KeyError(f'invalid mode={repr(mode)}')
-    # return batches
-    return np.array_split(idxs, num_batches)
-
-
-def generate_epochs_batch_idxs(num_obs: int, num_epochs: int, num_epoch_batches: int, mode: str = 'shuffle') -> List[np.ndarray]:
-    """
-    Like generate_epoch_batch_idxs, but concatenate the batches of calling the function `num_epochs` times.
-    - The total number of batches returned is: `num_epochs * num_epoch_batches`
-    """
-    batches = []
-    for i in range(num_epochs):
-        batches.extend(generate_epoch_batch_idxs(num_obs=num_obs, num_batches=num_epoch_batches, mode=mode))
-    return batches
-
-
-# ========================================================================= #
-# Dataloader Sampler Utilities                                              #
-# ========================================================================= #
-
-
-class StochasticSampler(torch.utils.data.Sampler):
-    """
-    Sample random batches, not guaranteed to be unique or cover the entire dataset in one epoch!
-    """
-
-    def __init__(self, data_source: Union[Sized, int], batch_size: int = 128):
-        super().__init__(data_source)
-        if isinstance(data_source, int):
-            self._len = data_source
-        else:
-            self._len = len(data_source)
-        self._batch_size = batch_size
-        assert isinstance(self._len, int)
-        assert self._len > 0
-        assert isinstance(self._batch_size, int)
-        assert self._batch_size > 0
-
-    def __iter__(self):
-        while True:
-            yield from np.random.randint(0, self._len, size=self._batch_size)
-
-
-def yield_dataloader(dataloader: torch.utils.data.DataLoader, steps: int):
-    i = 0
-    while True:
-        for it in dataloader:
-            yield it
-            i += 1
-            if i >= steps:
-                return
-
-
-def StochasticBatchSampler(data_source: Union[Sized, int], batch_size: int):
-    return torch.utils.data.BatchSampler(
-        sampler=StochasticSampler(data_source=data_source, batch_size=batch_size),
-        batch_size=batch_size,
-        drop_last=True
-    )
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/research/util/_fn_util.py b/research/util/_fn_util.py
deleted file mode 100644
index 471bd3fe..00000000
--- a/research/util/_fn_util.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import inspect -from typing import Sequence - -from disent.util.deprecate import deprecated - - -# ========================================================================= # -# Function Arguments # -# ========================================================================= # - - -def _get_fn_from_stack(fn_name: str, stack): - # -- do we actually need all of this? - fn = None - for s in stack: - if fn_name in s.frame.f_locals: - fn = s.frame.f_locals[fn_name] - break - if fn is None: - raise RuntimeError(f'could not retrieve function: {repr(fn_name)} from call stack.') - return fn - - -@deprecated('function uses bad mechanics, see commented implementation below') -def get_caller_params(sort: bool = False, exclude: Sequence[str] = None) -> dict: - stack = inspect.stack() - fn_name = stack[1].function - fn_locals = stack[1].frame.f_locals - # get function and params - fn = _get_fn_from_stack(fn_name, stack) - fn_params = inspect.getfullargspec(fn).args - # check excluded - exclude = set() if (exclude is None) else set(exclude) - fn_params = [p for p in fn_params if (p not in exclude)] - # sort values - if sort: - fn_params = sorted(fn_params) - # return dict - return { - k: fn_locals[k] for k in fn_params - } - - -def params_as_string(params: dict, sep: str = '_', names: bool = False): - # get strings - if names: - return sep.join(f"{k}={v}" for k, v in params.items()) - else: - return sep.join(f"{v}" for k, v in params.items()) - - -# ========================================================================= # -# END # -# ========================================================================= # - - -# TODO: replace function above -# -# class DELETED(object): -# def __str__(self): return '' -# def __repr__(self): return str(self) -# -# -# DELETED = DELETED() -# -# -# def get_hparams(exclude: Union[Sequence[str], Set[str]] = None): -# # check values -# if exclude is None: -# exclude = {} -# else: -# exclude = set(exclude) -# # get frame and values -# args = inspect.getargvalues(frame=inspect.currentframe().f_back) -# # sort values -# arg_names = list(args.args) -# if args.varargs is not None: arg_names.append(args.varargs) -# if args.keywords is not None: arg_names.append(args.keywords) -# # filter values -# from argparse import Namespace -# return Namespace(**{ -# name: args.locals.get(name, DELETED) -# for name in arg_names -# if (name not in exclude) -# }) - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/research/util/_io_util.py b/research/util/_io_util.py deleted file mode 100644 index 3c5e8ca8..00000000 --- a/research/util/_io_util.py +++ /dev/null @@ -1,239 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of 
this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - -import base64 -import dataclasses -import inspect -import io -import os -from typing import Optional -from typing import Union - -import torch - -from disent.util.inout.paths import ensure_parent_dir_exists - - -# ========================================================================= # -# Github Upload Utility Functions # -# ========================================================================= # - - -def gh_get_repo(repo: str = None): - from github import Github - # get token str - token = os.environ.get('GITHUB_TOKEN', '') - if not token.strip(): - raise ValueError('`GITHUB_TOKEN` env variable has not been set!') - assert isinstance(token, str) - # get repo str - if repo is None: - repo = os.environ.get('GITHUB_REPO', '') - if not repo.strip(): - raise ValueError('`GITHUB_REPO` env variable has not been set!') - assert isinstance(repo, str) - # get repo - return Github(token).get_repo(repo) - - -def gh_get_branch(repo: 'Repository', branch: str = None, source_branch: str = None, allow_new_branch: bool = True) -> 'Branch': - from github import GithubException - # check branch - assert isinstance(branch, str) or (branch is None) - assert isinstance(source_branch, str) or (source_branch is None) - # get default branch - if branch is None: - branch = repo.default_branch - # retrieve branch - try: - return repo.get_branch(branch) - except GithubException as e: - if not allow_new_branch: - raise RuntimeError(f'Creating branch disabled, set `allow_new_branch=True`: {repr(branch)}') - print(f'Creating missing branch: {repr(branch)}') - sb = repo.get_branch(repo.default_branch if (source_branch is None) else source_branch) - repo.create_git_ref(ref='refs/heads/' + branch, sha=sb.commit.sha) - return repo.get_branch(branch) - - -@dataclasses.dataclass -class WriteResult: - commit: 'Commit' - content: 'ContentFile' - - -def gh_write_file(repo: 'Repository', path: str, content: Union[str, bytes], branch: str = None, allow_new_file=True, allow_overwrite_file=False, allow_new_branch=True) -> WriteResult: - from github import UnknownObjectException - # get branch - branch = gh_get_branch(repo, branch, allow_new_branch=allow_new_branch).name - # check that the file exists - try: - sha = repo.get_contents(path, ref=branch).sha - except UnknownObjectException: - sha = None - # handle file exists or not - if sha is None: - if not allow_new_file: - raise RuntimeError(f'Creating file disabled, set `allow_new_file=True`: {repr(path)}') - result = 
repo.create_file(path=path, message=f'Created File: {path}', content=content, branch=branch)
-    else:
-        if not allow_overwrite_file:
-            raise RuntimeError(f'Overwriting file disabled, set `allow_overwrite_file=True`: {repr(path)}')
-        result = repo.update_file(path=path, message=f'Updated File: {path}', content=content, branch=branch, sha=sha)
-    # result is a dict: {'commit': github.Commit, 'content': github.ContentFile}
-    return WriteResult(**result)
-
-
-# ========================================================================= #
-# Github Upload Utility Class                                               #
-# ========================================================================= #
-
-
-class GithubWriter(object):
-
-    def __init__(self, repo: str = None, branch: str = None, allow_new_file=True, allow_overwrite_file=True, allow_new_branch=True):
-        self._kwargs = dict(
-            repo=gh_get_repo(repo=repo),
-            branch=branch,
-            allow_new_file=allow_new_file,
-            allow_overwrite_file=allow_overwrite_file,
-            allow_new_branch=allow_new_branch,
-        )
-
-    def write_file(self, path: str, content: Union[str, bytes]):
-        return gh_write_file(
-            path=path,
-            content=content,
-            **self._kwargs,
-        )
-
-
-# ========================================================================= #
-# Torch Save Utils                                                          #
-# ========================================================================= #
-
-
-def torch_save_bytes(model) -> bytes:
-    buffer = io.BytesIO()
-    torch.save(model, buffer)
-    buffer.seek(0)
-    return buffer.read()
-
-
-def torch_save_base64(model) -> str:
-    b = torch_save_bytes(model)
-    return base64.b64encode(b).decode('ascii')
-
-
-def torch_load_bytes(b: bytes):
-    return torch.load(io.BytesIO(b))
-
-
-def torch_load_base64(s: str):
-    b = base64.b64decode(s.encode('ascii'))
-    return torch_load_bytes(b)
-
-
-# ========================================================================= #
-# write                                                                     #
-# ========================================================================= #
-
-
-def _split_special_path(path):
-    if path.startswith('github:'):
-        # get github repo and path
-        path = path[len('github:'):]
-        repo, path = os.path.join(*path.split('/')[:2]), os.path.join(*path.split('/')[2:])
-        # check paths
-        assert repo.strip() and len(repo.split('/')) == 2
-        assert path.strip() and len(path.split('/')) >= 1
-        # return components
-        return 'github', (repo, path)
-    else:
-        return 'local', path
-
-
-def torch_write(path: str, model):
-    path_type, path = _split_special_path(path)
-    # handle cases
-    if path_type == 'github':
-        repo, path = path  # the components are ordered (repo, path)
-        # upload the file to the repo
-        ghw = GithubWriter(repo)
-        ghw.write_file(path=path, content=torch_save_bytes(model))
-        print(f'Saved in repo: {repr(repo)} to file: {repr(path)}')
-    elif path_type == 'local':
-        torch.save(model, ensure_parent_dir_exists(path))
-        print(f'Saved to file: {repr(path)}')
-    else:
-        raise KeyError(f'unknown path type: {repr(path_type)}')
-
-
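A minimal round-trip sketch of the save helpers above. The tensor and file names are arbitrary; the commented-out call only illustrates the `github:<user>/<repo>/<path>` format parsed by `_split_special_path`, and would additionally need the `GITHUB_TOKEN` env variable to actually run:

    import torch

    t = torch.arange(4)
    s = torch_save_base64(t)                     # ascii string, safe to embed in text formats
    assert torch.equal(torch_load_base64(s), t)  # round-trip recovers the tensor

    torch_write('data/example/tensor.pt', t)     # local save, parent dirs created if needed
    # torch_write('github:user/repo/data/tensor.pt', t)  # -> repo='user/repo', path='data/tensor.pt'
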
-
-
-# ========================================================================= #
-# Files                                                                     #
-# ========================================================================= #
-
-
-def _make_rel_path(*path_segments, is_file=True, _calldepth=0):
-    assert not os.path.isabs(os.path.join(*path_segments)), 'path must be relative'
-    # get source
-    stack = inspect.stack()
-    module = inspect.getmodule(stack[_calldepth+1].frame)
-    reldir = os.path.dirname(module.__file__)
-    # make everything
-    path = os.path.join(reldir, *path_segments)
-    folder_path = os.path.dirname(path) if is_file else path
-    os.makedirs(folder_path, exist_ok=True)
-    return path
-
-
-def _make_rel_path_add_ext(*path_segments, ext='.png', _calldepth=0):
-    # make path
-    path = _make_rel_path(*path_segments, is_file=True, _calldepth=_calldepth+1)
-    if not os.path.splitext(path)[1]:
-        path = f'{path}{ext}'
-    return path
-
-
-def make_rel_path(*path_segments, is_file=True):
-    return _make_rel_path(*path_segments, is_file=is_file, _calldepth=1)
-
-
-def make_rel_path_add_ext(*path_segments, ext='.png'):
-    return _make_rel_path_add_ext(*path_segments, ext=ext, _calldepth=1)
-
-
-def plt_rel_path_savefig(rel_path: Optional[str], save: bool = True, show: bool = True, ext='.png', dpi: Optional[int] = None):
-    import matplotlib.pyplot as plt
-    if save and (rel_path is not None):
-        path = _make_rel_path_add_ext(rel_path, ext=ext, _calldepth=2)
-        plt.savefig(path, dpi=dpi)
-        print(f'saved: {repr(path)}')
-    if show:
-        plt.show()
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
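
[Aside, added by the editor: `_make_rel_path` inspects the call stack, so paths resolve relative to the *caller's* module file rather than the working directory. A sketch with a hypothetical call site:]

    # called from some module under research/ -- resolves next to that module's file
    path = make_rel_path_add_ext('plots', 'recon_grid')  # -> <module_dir>/plots/recon_grid.png
    # the parent directory is created automatically; '.png' is appended since no ext was given
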
diff --git a/research/util/_loss.py b/research/util/_loss.py
deleted file mode 100644
index 50cbc1ab..00000000
--- a/research/util/_loss.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-# MIT License
-#
-# Copyright (c) 2021 Nathan Juraj Michlo
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
-
-import inspect
-import warnings
-from typing import Optional
-from typing import Sequence
-
-import numpy as np
-import torch
-from torch.nn import functional as F
-
-from disent import registry
-from disent.nn.loss.reduction import batch_loss_reduction
-
-
-# ========================================================================= #
-# optimizer                                                                 #
-# ========================================================================= #
-
-
-_SPECIALIZATIONS = {'sgd_m': ('sgd', dict(momentum=0.1))}
-
-
-def make_optimizer(model: torch.nn.Module, name: str = 'sgd', lr=1e-3, weight_decay: Optional[float] = None):
-    if isinstance(model, torch.nn.Module):
-        params = model.parameters()
-    elif isinstance(model, torch.Tensor):
-        assert model.requires_grad
-        params = [model]
-    else:
-        raise TypeError(f'cannot optimize type: {type(model)}')
-    # get specializations
-    kwargs = {}
-    if name in _SPECIALIZATIONS:
-        name, kwargs = _SPECIALIZATIONS[name]
-    # get optimizer class
-    optimizer_cls = registry.OPTIMIZERS[name]
-    optimizer_params = set(inspect.signature(optimizer_cls).parameters.keys())
-    # add optional arguments
-    if weight_decay is not None:
-        if 'weight_decay' in optimizer_params:
-            kwargs['weight_decay'] = weight_decay
-        else:
-            warnings.warn(f'{name}: weight decay cannot be set, optimizer does not have `weight_decay` parameter')
-    # instantiate
-    return optimizer_cls(params, lr=lr, **kwargs)
-
-
-def step_optimizer(optimizer, loss):
-    optimizer.zero_grad()
-    loss.backward()
-    optimizer.step()
-
-
-# ========================================================================= #
-# Loss                                                                      #
-# ========================================================================= #
-
-
-def _unreduced_mse_loss(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
-    return F.mse_loss(pred, targ, reduction='none')
-
-
-def _unreduced_mae_loss(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
-    return torch.abs(pred - targ)
-
-
-def unreduced_loss(pred: torch.Tensor, targ: torch.Tensor, mode='mse') -> torch.Tensor:
-    return _LOSS_FNS[mode](pred, targ)
-
-
-_LOSS_FNS = {
-    'mse': _unreduced_mse_loss,
-    'mae': _unreduced_mae_loss,
-}
-
-
-# ========================================================================= #
-# Pairwise Loss                                                             #
-# ========================================================================= #
-
-
-def pairwise_loss(pred: torch.Tensor, targ: torch.Tensor, mode='mse', mean_dtype=None, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-    # check input
-    assert pred.shape == targ.shape
-    # compute the unreduced elementwise loss
-    loss = unreduced_loss(pred=pred, targ=targ, mode=mode)
-    # mask values
-    if mask is not None:
-        loss *= mask
-    # reduce -- mean over all but the batch dim
-    loss = batch_loss_reduction(loss, reduction_dtype=mean_dtype, reduction='mean')
-    # check result
-    assert loss.shape == pred.shape[:1]
-    # done
-    return loss
-
-
-def unreduced_overlap(pred: torch.Tensor, targ: torch.Tensor, mode='mse') -> torch.Tensor:
-    # -ve loss
-    return - unreduced_loss(pred=pred, targ=targ, mode=mode)
-
-
-def pairwise_overlap(pred: torch.Tensor, targ: torch.Tensor, mode='mse', mean_dtype=None) -> torch.Tensor:
-    # -ve loss
-    return - pairwise_loss(pred=pred, targ=targ, mode=mode, mean_dtype=mean_dtype)
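
[Aside, an added shape sketch, not part of the diff: `pairwise_loss` reduces everything except the batch dimension, so one value comes back per batch element, as its own assert guarantees.]

    import torch
    pred = torch.rand(8, 3, 64, 64)
    targ = torch.rand(8, 3, 64, 64)
    loss = pairwise_loss(pred, targ, mode='mse')
    assert loss.shape == (8,)  # one loss value per batch element
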
-
-
-# ========================================================================= #
-# Factor Distances                                                          #
-# ========================================================================= #
-
-
-def np_factor_dists(
-    factors_a: np.ndarray,
-    factors_b: np.ndarray,
-    factor_sizes: Optional[Sequence[int]] = None,
-    circular_if_factor_sizes: bool = True,
-    p: int = 1,
-) -> np.ndarray:
-    assert factors_a.ndim == 2
-    assert factors_a.shape == factors_b.shape
-    # compute factor distances
-    fdists = np.abs(factors_a - factors_b)  # (NUM, FACTOR_SIZE)
-    # circular distance
-    if (factor_sizes is not None) and circular_if_factor_sizes:
-        M = np.array(factor_sizes)[None, :]  # (FACTOR_SIZE,) -> (1, FACTOR_SIZE)
-        assert M.shape == (1, factors_a.shape[-1])
-        fdists = np.where(fdists > (M // 2), M - fdists, fdists)  # (NUM, FACTOR_SIZE)
-    # compute final dists
-    fdists = (fdists ** p).sum(axis=-1) ** (1 / p)
-    # return values
-    return fdists  # (NUM,)
-
-
-# ========================================================================= #
-# END                                                                       #
-# ========================================================================= #
diff --git a/research/wandb_cli.py b/research/wandb_cli.py
deleted file mode 100644
index a97dce8a..00000000
--- a/research/wandb_cli.py
+++ /dev/null
@@ -1,31 +0,0 @@
-
-"""
-This file is an alias to the wandb cli that first sets the
-temporary directory to a different folder `/tmp/<user>/wandb`,
-in case `/tmp` has been polluted or you don't have the correct
-access rights to modify files.
-
-- I am not sure why this is needed; it is probably a bug in
-  wandb (or even tempfile) not respecting the `TMPDIR`, `TEMP`
-  and `TMP` environment variables, which, when set, should have
-  the same effect as the patch below. See the tempfile docs:
-  https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
-"""
-
-# wandb_cli.py
-if __name__ == '__main__':
-    import os
-    import tempfile
-
-    # generate the temporary directory from the user
-    temp_dir = f'/tmp/{os.environ["USER"]}/wandb'
-    print(f'[PATCHING:] tempfile.tempdir={repr(temp_dir)}')
-
-    # we need to patch tempdir before we can import wandb
-    assert tempfile.tempdir is None
-    os.makedirs(temp_dir, exist_ok=True)
-    tempfile.tempdir = temp_dir
-
-    # taken from wandb.__main__
-    from wandb.cli.cli import cli
-    cli(prog_name="python -m wandb")
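
[Aside, an editor's note, not part of the diff: the docstring above wonders why the environment-variable route fails. Per the tempfile docs, setting `TMPDIR` before the first call to `gettempdir()` should normally be enough, e.g.:]

    import os
    import tempfile

    # must run before anything first calls tempfile.gettempdir(), since the result is cached
    os.environ['TMPDIR'] = f'/tmp/{os.environ["USER"]}/wandb'
    os.makedirs(os.environ['TMPDIR'], exist_ok=True)
    print(tempfile.gettempdir())  # -> /tmp/<user>/wandb
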
diff --git a/research/working-batch.sh b/research/working-batch.sh
deleted file mode 100644
index 53855652..00000000
--- a/research/working-batch.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export PROJECT="N/A"
-export USERNAME="N/A"
-export PARTITION="batch"
-export PARALLELISM=24
-
-# source the helper file
-source "$(dirname "$(realpath -s "$0")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cuda_nodes "$PARTITION" 43200 "W-disent"  # 12 hours
diff --git a/research/working-stampede.sh b/research/working-stampede.sh
deleted file mode 100644
index c4d7170b..00000000
--- a/research/working-stampede.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-# ========================================================================= #
-# Settings                                                                  #
-# ========================================================================= #
-
-export PROJECT="N/A"
-export USERNAME="N/A"
-export PARTITION="stampede"
-export PARALLELISM=24
-
-# source the helper file
-source "$(dirname "$(realpath -s "$0")")/helper.sh"
-
-# ========================================================================= #
-# Experiment                                                                #
-# ========================================================================= #
-
-clog_cuda_nodes "$PARTITION" 43200 "W-disent"  # 12 hours
diff --git a/tests/test_data_similarity.py b/tests/test_data_similarity.py
index a22c387b..33169471 100644
--- a/tests/test_data_similarity.py
+++ b/tests/test_data_similarity.py
@@ -26,8 +26,6 @@
 from disent.dataset.data import XYObjectData
 from disent.dataset.data import XYObjectShadedData
-from disent.dataset.data import XYSquaresData # pragma: delete-on-release
-from disent.dataset.data import XYSquaresMinimalData # pragma: delete-on-release
 
 
 # ========================================================================= #
@@ -35,18 +33,6 @@
 # ========================================================================= #
 
 
-def test_xysquares_similarity(): # pragma: delete-on-release
-    data_org = XYSquaresData() # pragma: delete-on-release
-    data_min = XYSquaresMinimalData() # pragma: delete-on-release
-    # check lengths # pragma: delete-on-release
-    assert len(data_org) == len(data_min) # pragma: delete-on-release
-    n = len(data_min) # pragma: delete-on-release
-    # check items # pragma: delete-on-release
-    for i in np.random.randint(0, n, size=100): # pragma: delete-on-release
-        assert np.allclose(data_org[i], data_min[i]) # pragma: delete-on-release
-    # check bounds # pragma: delete-on-release
-    assert np.allclose(data_org[0], data_min[0]) # pragma: delete-on-release
-    assert np.allclose(data_org[n-1], data_min[n-1]) # pragma: delete-on-release
 
 
 def test_xyobject_similarity():
diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py
index 438150d1..91c1750c 100644
--- a/tests/test_frameworks.py
+++ b/tests/test_frameworks.py
@@ -35,9 +35,7 @@
 from disent.dataset.sampling import GroundTruthPairSampler
 from disent.dataset.sampling import GroundTruthTripleSampler
 from disent.frameworks.ae import *
-from disent.frameworks.ae.experimental import * # pragma: delete-on-release
 from disent.frameworks.vae import *
-from disent.frameworks.vae.experimental import * # pragma: delete-on-release
 from disent.model import AutoEncoder
 from disent.model.ae import DecoderLinear
 from disent.model.ae import EncoderLinear
@@ -52,16 +50,10 @@
 @pytest.mark.parametrize(['Framework', 'cfg_kwargs', 'Data'], [
     # AE - unsupervised
     (Ae, dict(), XYObjectData),
-    # AE - unsupervised - EXP # pragma: delete-on-release
-    (DataOverlapTripletAe, dict(overlap_mine_triplet_mode='hard_neg'), XYObjectData), # pragma: delete-on-release
     # AE - weakly supervised
     #
-    # AE - weakly supervised - EXP # pragma: delete-on-release
-    (AdaAe, dict(), XYObjectData), # pragma: delete-on-release
     # AE - supervised
     (TripletAe, dict(), XYObjectData),
-    # AE - supervised - EXP # pragma: delete-on-release
-    (AdaNegTripletAe, dict(), XYObjectData), # pragma: delete-on-release
     # VAE - unsupervised
     (Vae, dict(), XYObjectData),
     (BetaVae, dict(), XYObjectData),
@@ -71,34 +63,13 @@
     (DfcVae, dict(), XYObjectData),
     (DfcVae, dict(), partial(XYObjectData, rgb=False)),
     (BetaTcVae, dict(), XYObjectData),
-    # VAE - unsupervised - EXP # pragma: delete-on-release
-    (DataOverlapTripletVae,dict(overlap_mine_triplet_mode='none'), XYObjectData), # pragma: delete-on-release
-    (DataOverlapTripletVae,dict(overlap_mine_triplet_mode='semi_hard_neg'), XYObjectData), # pragma: delete-on-release
-    (DataOverlapTripletVae,dict(overlap_mine_triplet_mode='hard_neg'), XYObjectData), # pragma: delete-on-release
-    (DataOverlapTripletVae,dict(overlap_mine_triplet_mode='hard_pos'), XYObjectData), # pragma: delete-on-release
-    (DataOverlapTripletVae,dict(overlap_mine_triplet_mode='easy_pos'), XYObjectData), # pragma: delete-on-release
-    (DataOverlapRankVae, dict(), XYObjectData), # pragma: delete-on-release
     # VAE - weakly supervised
     (AdaVae, dict(), XYObjectData),
     (AdaVae, dict(ada_average_mode='ml-vae'), XYObjectData),
     (AdaGVaeMinimal, dict(), XYObjectData),
-    # VAE - weakly supervised - EXP # pragma: delete-on-release
-    (SwappedTargetAdaVae, dict(swap_chance=1.0), XYObjectData), # pragma: delete-on-release
-    (SwappedTargetBetaVae, dict(swap_chance=1.0), XYObjectData), # pragma: delete-on-release
-    (AugPosTripletVae, dict(), XYObjectData), # pragma: delete-on-release
     # VAE - supervised
     (TripletVae, dict(), XYObjectData),
     (TripletVae, dict(disable_decoder=True, disable_reg_loss=True, disable_posterior_scale=0.5), XYObjectData),
-    # VAE - supervised - EXP # pragma: delete-on-release
-    (BoundedAdaVae, dict(), XYObjectData), # pragma: delete-on-release
-    (GuidedAdaVae, dict(), XYObjectData), # pragma: delete-on-release
-    (GuidedAdaVae, dict(gada_anchor_ave_mode='thresh'), XYObjectData), # pragma: delete-on-release
-    (TripletBoundedAdaVae, dict(), XYObjectData), # pragma: delete-on-release
-    (TripletGuidedAdaVae, dict(), XYObjectData), # pragma: delete-on-release
-    (AdaTripletVae, dict(), XYObjectData), # pragma: delete-on-release
-    (AdaAveTripletVae, dict(adat_share_mask_mode='posterior'), XYObjectData), # pragma: delete-on-release
-    (AdaAveTripletVae, dict(adat_share_mask_mode='sample'), XYObjectData), # pragma: delete-on-release
-    (AdaAveTripletVae, dict(adat_share_mask_mode='sample_each'), XYObjectData), # pragma: delete-on-release
 ])
 def test_frameworks(Framework, cfg_kwargs, Data):
     DataSampler = {
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 5d67b4a0..92b0a9bc 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -43,8 +43,6 @@
     wrapped_partial(metric_dci, num_train=7, num_test=7),
     wrapped_partial(metric_sap, num_train=7, num_test=7),
     wrapped_partial(metric_factor_vae, num_train=7, num_eval=7, num_variance_estimate=7),
-    wrapped_partial(metric_flatness, factor_repeats=7), # pragma: delete-on-release
-    wrapped_partial(metric_flatness_components, factor_repeats=7), # pragma: delete-on-release
 ])
 def test_metrics(metric_fn):
     z_size = 8
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 399c1f62..c173d969 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -44,17 +44,6 @@
 }
 
 
-COUNTS = { # pragma: delete-on-release
-    'DATASETS': 10, # pragma: delete-on-release
-    'SAMPLERS': 8, # pragma: delete-on-release
-    'FRAMEWORKS': 25, # pragma: delete-on-release
-    'RECON_LOSSES': 6, # pragma: delete-on-release
-    'LATENT_DISTS': 2, # pragma: delete-on-release
-    'OPTIMIZERS': 30, # pragma: delete-on-release
-    'METRICS': 7, # pragma: delete-on-release
-    'SCHEDULES': 5, # pragma: delete-on-release
-    'MODELS': 8, # pragma: delete-on-release
-} # pragma: delete-on-release
 
 
 def test_registry_loading():
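
[Closing aside, added by the editor and not part of the diff: every deletion above is tagged with the literal marker `# pragma: delete-on-release`, which suggests the patch was produced mechanically. A minimal sketch of such a line filter, under the assumption that whole tagged lines are simply dropped; the CLI shape is hypothetical.]

    import sys

    PRAGMA = '# pragma: delete-on-release'

    def strip_pragma_lines(text: str) -> str:
        # drop any line tagged with the release pragma, keep everything else
        return '\n'.join(line for line in text.splitlines() if PRAGMA not in line)

    if __name__ == '__main__':
        # usage sketch: python strip_pragmas.py < input.py > output.py
        sys.stdout.write(strip_pragma_lines(sys.stdin.read()))
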