Commit 8a53282: Merge branch 'dev' into main

nmichlo committed Nov 22, 2021
2 parents 82cd508 + 066e9af commit 8a53282
Showing 16 changed files with 184 additions and 90 deletions.
6 changes: 5 additions & 1 deletion disent/dataset/data/_groundtruth__dsprites_imagenet.py
@@ -62,9 +62,13 @@ def __getitem__(self, idx):
return np.array(img)


def _noop(x):
return x


def load_imagenet_tiny_data(raw_data_dir):
data = NumpyFolder(os.path.join(raw_data_dir, 'train'))
data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=lambda x: x)
data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=_noop)
# load data -- somewhat memory inefficient compared to looping into a pre-allocated array
imgs = np.concatenate(list(tqdm(data, 'loading')), axis=0)
assert imgs.shape == (100_000, 64, 64, 3)
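
Note on the `_noop` change above: `DataLoader` worker processes pickle their `collate_fn`, and a lambda cannot be pickled under the 'spawn' start method (the default on Windows and macOS), so an identity collate must live at module level. A minimal sketch using only standard PyTorch; the `_Range` dataset is hypothetical, for illustration only:

from torch.utils.data import DataLoader, Dataset

class _Range(Dataset):  # hypothetical toy dataset
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return idx

def _noop(x):
    return x  # module-level, therefore picklable

if __name__ == '__main__':
    ok = DataLoader(_Range(), batch_size=2, num_workers=2, collate_fn=_noop)
    print(list(ok))  # [[0, 1], [2, 3]] -- each batch passed through unchanged
    # with collate_fn=lambda x: x this raises a pickling error when workers spawn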
13 changes: 0 additions & 13 deletions disent/dataset/data/_groundtruth__xyobject.py
@@ -116,19 +116,7 @@ def __init__(
rgb: bool = True,
palette: str = 'rainbow_4',
transform=None,
warn_: bool = True
):
if warn_:
warnings.warn(
'`XYObjectData` defaults were changed in disent v0.3.0; if you want approximately <= v0.2.x behavior then use the following parameters. Palettes also changed slightly.'
'\n\tgrid_size=64'
'\n\tgrid_spacing=1'
'\n\tmin_square_size=3'
'\n\tmax_square_size=9'
'\n\tsquare_size_spacing=2'
'\n\trgb=True'
'\n\tpalette="colors_1"'
)
# generation
self._rgb = rgb
# check the palette name
@@ -238,7 +226,6 @@ def __init__(
rgb=rgb,
palette=f'{palette}_1',
transform=transform,
warn_=False,
)

def _get_observation(self, idx):
6 changes: 4 additions & 2 deletions disent/dataset/transform/_augment.py
@@ -185,8 +185,10 @@ class FftKernel(DisentModule):

def __init__(self, kernel: Union[torch.Tensor, str], normalize: bool = True):
super().__init__()
# load the kernel
self._kernel = torch.nn.Parameter(get_kernel(kernel, normalize=normalize), requires_grad=False)
# load & save the kernel -- no gradients allowed
self._kernel: torch.Tensor
self.register_buffer('_kernel', get_kernel(kernel, normalize=normalize), persistent=True)
self._kernel.requires_grad = False

def forward(self, obs):
# add or remove batch dim
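The buffer registration above changes how the kernel is tracked: unlike a `Parameter` with `requires_grad=False`, a buffer never appears in `model.parameters()`, so an optimizer built from them cannot touch it, while it still moves with `.to(device)` and is saved in the `state_dict` (`persistent=True`). A small sketch of the distinction, using only standard PyTorch:

import torch

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # in state_dict, moved by .to(device), absent from parameters()
        self.register_buffer('k', torch.ones(3), persistent=True)

class WithParam(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # excluded from gradients, but still yielded by parameters()
        self.k = torch.nn.Parameter(torch.ones(3), requires_grad=False)

assert len(list(WithBuffer().parameters())) == 0
assert len(list(WithParam().parameters())) == 1
assert 'k' in WithBuffer().state_dict() and 'k' in WithParam().state_dict()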
97 changes: 70 additions & 27 deletions disent/dataset/util/stats.py
@@ -96,70 +96,113 @@ def main(progress=False):
data.DSpritesData,
data.SmallNorbData,
data.Shapes3dData,
wrapped_partial(data.Mpi3dData, subset='toy', in_memory=True),
wrapped_partial(data.Mpi3dData, subset='realistic', in_memory=True),
wrapped_partial(data.Mpi3dData, subset='real', in_memory=True),
# groundtruth -- impl synthetic
data.XYObjectData,
data.XYObjectShadedData,
# large datasets
(data.Mpi3dData, dict(subset='toy', in_memory=True)),
(data.Mpi3dData, dict(subset='realistic', in_memory=True)),
(data.Mpi3dData, dict(subset='real', in_memory=True)),
]:
from disent.dataset.transform import ToImgTensorF32
# get arguments
if isinstance(data_cls, tuple):
data_cls, kwargs = data_cls
else:
data_cls, kwargs = data_cls, {}
# The most common standardized way of computing the mean and std over observations:
# images resized to 64px, stored as float32 in the range [0, 1].
data = data_cls(transform=ToImgTensorF32(size=64))
data = data_cls(transform=ToImgTensorF32(size=64), **kwargs)
mean, std = compute_data_mean_std(data, progress=progress)
# results!
print(f'{data.__class__.__name__} - {data.name}:\n mean: {mean.tolist()}\n std: {std.tolist()}')
print(f'{data.__class__.__name__} - {data.name} - {kwargs}:\n mean: {mean.tolist()}\n std: {std.tolist()}')

# RUN!
main()
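
For context, the statistics below are per-channel means and standard deviations over every pixel of every observation, computed on float32 images in [0, 1] resized to 64px. A rough sketch of an equivalent computation, assuming `compute_data_mean_std` behaves like this (the real disent implementation may batch and reduce differently):

import torch
from torch.utils.data import DataLoader

def naive_data_mean_std(data, batch_size=256):
    loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    count, total, total_sq = 0, 0.0, 0.0
    for batch in loader:  # batch: (B, C, H, W) float32 in [0, 1]
        flat = batch.transpose(0, 1).reshape(batch.shape[1], -1).double()  # (C, B*H*W)
        count += flat.shape[1]
        total = total + flat.sum(dim=1)
        total_sq = total_sq + (flat ** 2).sum(dim=1)
    mean = total / count
    std = (total_sq / count - mean ** 2).sqrt()  # population std via E[x^2] - E[x]^2
    return mean, std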


# ========================================================================= #
# RESULTS: 2021-10-12 #
# RESULTS: 2021-11-12 #
# ========================================================================= #


# Cars3dData - cars3d:
# Cars3dData - cars3d - {}:
# mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
# std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
# DSpritesData - dsprites:
# DSpritesData - dsprites - {}:
# mean: [0.042494423521889584]
# std: [0.19516645880626055]
# SmallNorbData - smallnorb:
# SmallNorbData - smallnorb - {}:
# mean: [0.7520918401088603]
# std: [0.09563879016827262]
# Shapes3dData - 3dshapes:
# std: [0.09563879016827263]
# Shapes3dData - 3dshapes - {}:
# mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578]
# std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748]
# Mpi3dData - mpi3d_toy:
# mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702]
# std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426]
# Mpi3dData - mpi3d_realistic:
# mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256]
# std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177]
# Mpi3dData - mpi3d_real:
# mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627]
# std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629]
# XYBlocksData - xyblocks:
# std: [0.2940814043555559, 0.34439790875172144, 0.3661685981524748]

# XYBlocksData - xyblocks - {}:
# mean: [0.10040509259259259, 0.10040509259259259, 0.10040509259259259]
# std: [0.21689087652106678, 0.21689087652106676, 0.21689087652106678]
# XYObjectData - xy_object:
# XYObjectData - xy_object - {}:
# mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288]
# std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585]
# XYObjectShadedData - xy_object:
# XYObjectShadedData - xy_object - {}:
# mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288]
# std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585]
# XYSquaresData - xy_squares:
# XYSquaresData - xy_squares - {}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
# XYSquaresMinimalData - xy_squares:
# XYSquaresMinimalData - xy_squares_minimal - {}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
# XColumnsData - x_columns:
# XColumnsData - x_columns - {}:
# mean: [0.125, 0.125, 0.125]
# std: [0.33075929223788925, 0.3307592922378891, 0.3307592922378892]

# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 8}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
# overlap between squares for reconstruction loss, 7 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 7}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
# overlap between squares for reconstruction loss, 6 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 6}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920855]
# overlap between squares for reconstruction loss, 5 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 5}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920855, 0.12403473458920854]
# overlap between squares for reconstruction loss, 4 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 4}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
# overlap between squares for reconstruction loss, 3 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 3}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
# overlap between squares for reconstruction loss, 2 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 2}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
# overlap between squares for reconstruction loss, 1 < 8
# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 1}:
# mean: [0.015625, 0.015625, 0.015625]
# std: [0.12403473458920855, 0.12403473458920855, 0.12403473458920855]
# XYSquaresData - xy_squares - {'rgb': False}:
# mean: [0.046146392822265625]
# std: [0.2096506119375896]

# Mpi3dData - mpi3d_toy - {'subset': 'toy', 'in_memory': True}:
# mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702]
# std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426]
# Mpi3dData - mpi3d_realistic - {'subset': 'realistic', 'in_memory': True}:
# mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256]
# std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177]
# Mpi3dData - mpi3d_real - {'subset': 'real', 'in_memory': True}:
# mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627]
# std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629]


# ========================================================================= #
# END #
17 changes: 11 additions & 6 deletions disent/frameworks/helper/reconstructions.py
@@ -241,7 +241,7 @@ def compute_unreduced_loss(self, x_recon, x_targ):

class AugmentedReconLossHandler(ReconLossHandler):

def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], kernel_weight=1.0):
def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], wrap_weight=1.0, aug_weight=1.0):
super().__init__(reduction=recon_loss_handler._reduction)
# save variables
self._recon_loss_handler = recon_loss_handler
@@ -251,16 +251,21 @@ def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torc
# load the kernel
self._kernel = FftKernel(kernel=kernel, normalize=True)
# kernel weighting
assert 0 <= kernel_weight <= 1, f'kernel weight must be in the range [0, 1] but received: {repr(kernel_weight)}'
self._kernel_weight = kernel_weight
assert 0 <= wrap_weight, f'wrap_weight must be in the range [0, inf) but received: {repr(wrap_weight)}'
assert 0 <= aug_weight, f'aug_weight must be in the range [0, inf) but received: {repr(aug_weight)}'
self._wrap_weight = wrap_weight
self._aug_weight = aug_weight
# disable gradients
for param in self.parameters():
param.requires_grad = False

def activate(self, x_partial: torch.Tensor):
return self._recon_loss_handler.activate(x_partial)

def compute_unreduced_loss(self, x_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
aug_loss = self._recon_loss_handler.compute_unreduced_loss(self._kernel(x_recon), self._kernel(x_targ))
loss = self._recon_loss_handler.compute_unreduced_loss(x_recon, x_targ)
return (1. - self._kernel_weight) * loss + self._kernel_weight * aug_loss
wrap_loss = self._recon_loss_handler.compute_unreduced_loss(x_recon, x_targ)
aug_loss = self._recon_loss_handler.compute_unreduced_loss(self._kernel(x_recon), self._kernel(x_targ))
return (self._wrap_weight * wrap_loss) + (self._aug_weight * aug_loss)

def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
return self.compute_unreduced_loss(self.activate(x_partial_recon), x_targ)
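
Where the old `kernel_weight` interpolated the two terms as `(1 - w) * wrap + w * aug`, the two new weights are independent, so both terms can be applied at full strength; the old behavior is recovered with `wrap_weight=1 - w, aug_weight=w`. A hedged usage sketch (the kernel tensor and its shape are assumptions, not confirmed by this diff):

import torch

# hypothetical construction -- any value accepted by get_kernel() works for `kernel`
handler = AugmentedReconLossHandler(
    recon_loss_handler=make_reconstruction_loss('mse', 'mean'),
    kernel=torch.ones(1, 3, 7, 7) / (3 * 7 * 7),  # e.g. a simple box-blur kernel (shape assumed)
    wrap_weight=1.0,  # weight of the unmodified reconstruction loss
    aug_weight=1.0,   # weight of the kernel-filtered reconstruction loss
)
x_recon, x_targ = torch.rand(8, 3, 64, 64), torch.rand(8, 3, 64, 64)
unreduced = handler.compute_unreduced_loss(x_recon, x_targ)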
4 changes: 4 additions & 0 deletions disent/frameworks/vae/_unsupervised__vae.py
@@ -158,6 +158,10 @@ def do_training_step(self, batch, batch_idx):
'recon_loss': recon_loss,
'reg_loss': reg_loss,
'aug_loss': aug_loss,
# ratios
'ratio_reg': (reg_loss / loss) if (loss != 0) else 0,
'ratio_rec': (recon_loss / loss) if (loss != 0) else 0,
'ratio_aug': (aug_loss / loss) if (loss != 0) else 0,
}

# --------------------------------------------------------------------- #
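The zero-guard in the new ratio entries matters because `loss` can legitimately be zero on degenerate batches; the same pattern in isolation, as a sketch:

def loss_ratios(loss, **components):
    # each component as a fraction of the total; 0 when the total itself is 0
    return {f'ratio_{name}': (part / loss) if (loss != 0) else 0
            for name, part in components.items()}

# e.g. loss_ratios(loss, reg=reg_loss, rec=recon_loss, aug=aug_loss)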
47 changes: 30 additions & 17 deletions disent/util/lightning/callbacks/_callbacks_vae.py
@@ -122,25 +122,25 @@ def _to_dmat(
return dmat


_AE_DIST_NAMES = ('x', 'z_l1', 'x_recon')
_VAE_DIST_NAMES = ('x', 'z_l1', 'kl', 'x_recon')
_AE_DIST_NAMES = ('x', 'z', 'x_recon')
_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon')


@torch.no_grad()
def _get_dists_ae(ae: Ae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor):
def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor):
# feed forward
z_a, z_b = ae.encode(x_a), ae.encode(x_b)
r_a, r_b = ae.decode(z_a), ae.decode(z_b)
# distances
return [
recon_loss.compute_pairwise_loss(x_a, x_b),
torch.norm(z_a - z_b, p=1, dim=-1), # l1 dist
recon_loss.compute_pairwise_loss(r_a, r_b),
ae.recon_handler.compute_pairwise_loss(x_a, x_b),
torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
ae.recon_handler.compute_pairwise_loss(r_a, r_b),
]


@torch.no_grad()
def _get_dists_vae(vae: Vae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor):
def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor):
from torch.distributions import kl_divergence
# feed forward
(z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b)
@@ -150,19 +150,19 @@ def _get_dists_vae(vae: Vae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_
kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a)
# distances
return [
recon_loss.compute_pairwise_loss(x_a, x_b),
torch.norm(z_a - z_b, p=1, dim=-1), # l1 dist
recon_loss._pairwise_reduce(kl_ab),
recon_loss.compute_pairwise_loss(r_a, r_b),
vae.recon_handler.compute_pairwise_loss(x_a, x_b),
torch.norm(z_a - z_b, p=2, dim=-1), # l2 dist
vae.recon_handler._pairwise_reduce(kl_ab),
vae.recon_handler.compute_pairwise_loss(r_a, r_b),
]


def _get_dists_fn(model, recon_loss: ReconLossHandler) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
# get aggregate function
if isinstance(model, Vae):
dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model, recon_loss)
dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model)
elif isinstance(model, Ae):
dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model, recon_loss)
dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model)
else:
dists_names, dists_fn = None, None
return dists_names, dists_fn
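
Two of the changes above are worth pinning down: the latent distance `z` is now an l2 rather than l1 norm, and the VAE path uses a symmetrized KL, kl_ab = 0.5 * KL(a||b) + 0.5 * KL(b||a), which is finite in both directions for diagonal Gaussians. A standalone sketch using only torch.distributions, with illustrative values:

import torch
from torch.distributions import Normal, kl_divergence

# two diagonal-Gaussian posteriors over a 10-dim latent space
z_post_a = Normal(loc=torch.zeros(10), scale=torch.ones(10))
z_post_b = Normal(loc=torch.ones(10), scale=2 * torch.ones(10))

# symmetrized KL, per latent dimension, as in _get_dists_vae above
kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a)

# the new l2 latent distance (previously p=1)
z_dist = torch.norm(z_post_a.mean - z_post_b.mean, p=2, dim=-1)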
@@ -303,7 +303,6 @@ def __init__(
assert traversal_repeats > 0
self._traversal_repeats = traversal_repeats
self._seed = seed
self._recon_loss = make_reconstruction_loss('mse', 'mean')
self._plt_block_size = plt_block_size
self._plt_show = plt_show
self._log_wandb = log_wandb
@@ -321,7 +320,7 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!')
return
# get aggregate function
dists_names, dists_fn = _get_dists_fn(vae, self._recon_loss)
dists_names, dists_fn = _get_dists_fn(vae)
if (dists_names is None) or (dists_fn is None):
log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
return
@@ -488,6 +487,19 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
plt.show()


def _normalized_numeric_metrics(items: dict):
results = {}
for k, v in items.items():
if isinstance(v, (float, int)):
results[k] = v
else:
try:
results[k] = float(v)
except (TypeError, ValueError):
log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}')
return results
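
In effect, anything float-convertible (ints, floats, 0-dim tensors, numpy scalars) passes through and everything else is dropped with a warning, e.g.:

import torch
_normalized_numeric_metrics({'acc': 0.9, 'n': 5, 'kl': torch.tensor(0.13), 'name': 'vae'})
# -> {'acc': 0.9, 'n': 5, 'kl': 0.13...}   ('name' is skipped with a warning)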


class VaeMetricLoggingCallback(BaseCallbackPeriodic):

def __init__(
@@ -521,10 +533,11 @@ def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningM
scores = metric(dataset, lambda x: vae.encode(x.to(vae.device)))
metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items())
log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}')

# log to trainer
prefix = 'final_metric' if is_final else 'epoch_metric'
prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()}
log_metrics(trainer.logger, prefixed_scores)
log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores))

# log summary for WANDB
# this is kinda hacky... the above should work for parallel coordinate plots
8 changes: 6 additions & 2 deletions disent/util/visualize/plot.py
@@ -135,7 +135,10 @@ def plt_subplots(
assert isinstance(ncols, int)
# check titles
if titles is not None:
titles = np.array(titles).reshape([nrows, ncols])
titles = np.array(titles)
# broadcast a flat list of titles onto the first row; remaining rows get None
if titles.ndim == 1:
titles = np.array([titles] + ([[None]*ncols] * (nrows-1)))
assert titles.ndim == 2
# get labels
if (row_labels is None) or isinstance(row_labels, str):
row_labels = [row_labels] * nrows
@@ -161,7 +164,8 @@
ax.set_ylabel(row_labels[y], fontsize=label_size)
# set title
if titles is not None:
ax.set_title(titles[y][x], fontsize=titles_size)
if titles[y][x] is not None:
ax.set_title(titles[y][x], fontsize=titles_size)
# set title
fig.suptitle(title, fontsize=title_size)
# done!
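With this change a flat list of titles now labels the top row only instead of being force-reshaped to (nrows, ncols). A hedged usage sketch, assuming `plt_subplots` returns `(fig, axs)` like `plt.subplots` and that the import path follows the file location:

import matplotlib.pyplot as plt
from disent.util.visualize.plot import plt_subplots

# three column titles on the first row of a 2x3 grid; remaining cells untitled
fig, axs = plt_subplots(nrows=2, ncols=3, titles=['a', 'b', 'c'])
plt.show()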
2 changes: 2 additions & 0 deletions experiment/config/run_action/skip.yaml
@@ -0,0 +1,2 @@
# @package _global_
action: skip
2 changes: 1 addition & 1 deletion experiment/config/schedule/beta_cyclic_fast.yaml
@@ -1,4 +1,4 @@
name: beta_cyclic
name: beta_cyclic_fast

schedule_items:
beta:
2 changes: 1 addition & 1 deletion experiment/config/schedule/beta_cyclic_slow.yaml
@@ -1,4 +1,4 @@
name: beta_cyclic
name: beta_cyclic_slow

schedule_items:
beta:
