diff --git a/disent/dataset/data/_groundtruth__dsprites_imagenet.py b/disent/dataset/data/_groundtruth__dsprites_imagenet.py
index 008aee28..13b80134 100644
--- a/disent/dataset/data/_groundtruth__dsprites_imagenet.py
+++ b/disent/dataset/data/_groundtruth__dsprites_imagenet.py
@@ -62,9 +62,13 @@ def __getitem__(self, idx):
         return np.array(img)
 
 
+def _noop(x):
+    return x
+
+
 def load_imagenet_tiny_data(raw_data_dir):
     data = NumpyFolder(os.path.join(raw_data_dir, 'train'))
-    data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=lambda x: x)
+    data = DataLoader(data, batch_size=64, num_workers=min(16, psutil.cpu_count(logical=False)), shuffle=False, drop_last=False, collate_fn=_noop)
     # load data - this is a bit memory inefficient doing it like this instead of with a loop into a pre-allocated array
     imgs = np.concatenate(list(tqdm(data, 'loading')), axis=0)
     assert imgs.shape == (100_000, 64, 64, 3)
diff --git a/disent/dataset/data/_groundtruth__xyobject.py b/disent/dataset/data/_groundtruth__xyobject.py
index 9c7db0fe..af3835f5 100644
--- a/disent/dataset/data/_groundtruth__xyobject.py
+++ b/disent/dataset/data/_groundtruth__xyobject.py
@@ -116,19 +116,7 @@ def __init__(
         rgb: bool = True,
         palette: str = 'rainbow_4',
         transform=None,
-        warn_: bool = True
     ):
-        if warn_:
-            warnings.warn(
-                '`XYObjectData` defaults were changed in disent v0.3.0, if you want `approx` <= v0.2.x behavior then use the following parameters. Pallets also changed slightly too.'
-                '\n\tgrid_size=64'
-                '\n\tgrid_spacing=1'
-                '\n\tmin_square_size=3'
-                '\n\tmax_square_size=9'
-                '\n\tsquare_size_spacing=2'
-                '\n\trgb=True'
-                '\n\tpalette="colors_1"'
-            )
         # generation
         self._rgb = rgb
         # check the pallete name
@@ -238,7 +226,6 @@ def __init__(
             rgb=rgb,
            palette=f'{palette}_1',
             transform=transform,
-            warn_=False,
         )
 
     def _get_observation(self, idx):
diff --git a/disent/dataset/transform/_augment.py b/disent/dataset/transform/_augment.py
index 20ea3d2d..69082a1d 100644
--- a/disent/dataset/transform/_augment.py
+++ b/disent/dataset/transform/_augment.py
@@ -185,8 +185,10 @@ class FftKernel(DisentModule):
 
     def __init__(self, kernel: Union[torch.Tensor, str], normalize: bool = True):
         super().__init__()
-        # load the kernel
-        self._kernel = torch.nn.Parameter(get_kernel(kernel, normalize=normalize), requires_grad=False)
+        # load & save the kernel -- no gradients allowed
+        self._kernel: torch.Tensor
+        self.register_buffer('_kernel', get_kernel(kernel, normalize=normalize), persistent=True)
+        self._kernel.requires_grad = False
 
     def forward(self, obs):
         # add or remove batch dim
diff --git a/disent/dataset/util/stats.py b/disent/dataset/util/stats.py
index c08f66a3..08be5c04 100644
--- a/disent/dataset/util/stats.py
+++ b/disent/dataset/util/stats.py
@@ -96,70 +96,113 @@ def main(progress=False):
         data.DSpritesData,
         data.SmallNorbData,
         data.Shapes3dData,
-        wrapped_partial(data.Mpi3dData, subset='toy', in_memory=True),
-        wrapped_partial(data.Mpi3dData, subset='realistic', in_memory=True),
-        wrapped_partial(data.Mpi3dData, subset='real', in_memory=True),
         # groundtruth -- impl synthetic
         data.XYObjectData,
         data.XYObjectShadedData,
+        # large datasets
+        (data.Mpi3dData, dict(subset='toy', in_memory=True)),
+        (data.Mpi3dData, dict(subset='realistic', in_memory=True)),
+        (data.Mpi3dData, dict(subset='real', in_memory=True)),
     ]:
         from disent.dataset.transform import ToImgTensorF32
+        # get arguments
+        if isinstance(data_cls, tuple):
+            data_cls, kwargs = data_cls
+        else:
+            data_cls, kwargs = data_cls, {}
         # Most common standardized way of computing the mean and std over observations
         # resized to 64px in size of dtype float32 in the range [0, 1].
-        data = data_cls(transform=ToImgTensorF32(size=64))
+        data = data_cls(transform=ToImgTensorF32(size=64), **kwargs)
         mean, std = compute_data_mean_std(data, progress=progress)
         # results!
-        print(f'{data.__class__.__name__} - {data.name}:\n    mean: {mean.tolist()}\n    std: {std.tolist()}')
+        print(f'{data.__class__.__name__} - {data.name} - {kwargs}:\n    mean: {mean.tolist()}\n    std: {std.tolist()}')
 
 # RUN!
 main()
 
 
 # ========================================================================= #
-# RESULTS: 2021-10-12                                                       #
+# RESULTS: 2021-11-12                                                       #
 # ========================================================================= #
-# Cars3dData - cars3d:
+# Cars3dData - cars3d - {}:
 #     mean: [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
 #     std: [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
-# DSpritesData - dsprites:
+# DSpritesData - dsprites - {}:
 #     mean: [0.042494423521889584]
 #     std: [0.19516645880626055]
-# SmallNorbData - smallnorb:
+# SmallNorbData - smallnorb - {}:
 #     mean: [0.7520918401088603]
-#     std: [0.09563879016827262]
-# Shapes3dData - 3dshapes:
+#     std: [0.09563879016827263]
+# Shapes3dData - 3dshapes - {}:
 #     mean: [0.502584966788819, 0.5787597566089667, 0.6034499731859578]
-#     std: [0.2940814043555559, 0.3443979087517214, 0.3661685981524748]
-# Mpi3dData - mpi3d_toy:
-#     mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702]
-#     std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426]
-# Mpi3dData - mpi3d_realistic:
-#     mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256]
-#     std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177]
-# Mpi3dData - mpi3d_real:
-#     mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627]
-#     std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629]
-# XYBlocksData - xyblocks:
+#     std: [0.2940814043555559, 0.34439790875172144, 0.3661685981524748]
+
+# XYBlocksData - xyblocks - {}:
 #     mean: [0.10040509259259259, 0.10040509259259259, 0.10040509259259259]
 #     std: [0.21689087652106678, 0.21689087652106676, 0.21689087652106678]
-# XYObjectData - xy_object:
+# XYObjectData - xy_object - {}:
 #     mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288]
 #     std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585]
-# XYObjectShadedData - xy_object:
+# XYObjectShadedData - xy_object - {}:
 #     mean: [0.009818761549013288, 0.009818761549013288, 0.009818761549013288]
 #     std: [0.052632363725245844, 0.05263236372524584, 0.05263236372524585]
-# XYSquaresData - xy_squares:
+# XYSquaresData - xy_squares - {}:
 #     mean: [0.015625, 0.015625, 0.015625]
 #     std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
-# XYSquaresMinimalData - xy_squares:
+# XYSquaresMinimalData - xy_squares_minimal - {}:
 #     mean: [0.015625, 0.015625, 0.015625]
 #     std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
-# XColumnsData - x_columns:
+# XColumnsData - x_columns - {}:
 #     mean: [0.125, 0.125, 0.125]
 #     std: [0.33075929223788925, 0.3307592922378891, 0.3307592922378892]
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 8}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 7 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 7}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 6 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 6}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920855]
+# overlap between squares for reconstruction loss, 5 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 5}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920855, 0.12403473458920855, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 4 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 4}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920855, 0.12403473458920854, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 3 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 3}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 2 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 2}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920854, 0.12403473458920854, 0.12403473458920854]
+# overlap between squares for reconstruction loss, 1 < 8
+# XYSquaresData - xy_squares - {'grid_size': 8, 'grid_spacing': 1}:
+#     mean: [0.015625, 0.015625, 0.015625]
+#     std: [0.12403473458920855, 0.12403473458920855, 0.12403473458920855]
+# XYSquaresData - xy_squares - {'rgb': False}:
+#     mean: [0.046146392822265625]
+#     std: [0.2096506119375896]
+
+# Mpi3dData - mpi3d_toy - {'subset': 'toy', 'in_memory': True}:
+#     mean: [0.22681593831231503, 0.22353985202496676, 0.22666059934624702]
+#     std: [0.07854112062669572, 0.07319301658077378, 0.0790763900050426]
+# Mpi3dData - mpi3d_realistic - {'subset': 'realistic', 'in_memory': True}:
+#     mean: [0.18240164396358813, 0.20723063241107917, 0.1820551008003256]
+#     std: [0.09511163559287175, 0.10128881101801782, 0.09428244469525177]
+# Mpi3dData - mpi3d_real - {'subset': 'real', 'in_memory': True}:
+#     mean: [0.13111154099374112, 0.16746449372488892, 0.14051725201807627]
+#     std: [0.10137409845578041, 0.10087824338375781, 0.10534121043187629]
+
 # ========================================================================= #
 # END                                                                       #
diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py
index 9821d15e..723d764f 100644
--- a/disent/frameworks/helper/reconstructions.py
+++ b/disent/frameworks/helper/reconstructions.py
@@ -241,7 +241,7 @@ def compute_unreduced_loss(self, x_recon, x_targ):
 
 class AugmentedReconLossHandler(ReconLossHandler):
 
-    def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], kernel_weight=1.0):
+    def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torch.Tensor], wrap_weight=1.0, aug_weight=1.0):
         super().__init__(reduction=recon_loss_handler._reduction)
         # save variables
         self._recon_loss_handler = recon_loss_handler
@@ -251,16 +251,21 @@ def __init__(self, recon_loss_handler: ReconLossHandler, kernel: Union[str, torc
         # load the kernel
         self._kernel = FftKernel(kernel=kernel, normalize=True)
         # kernel weighting
-        assert 0 <= kernel_weight <= 1, f'kernel weight must be in the range [0, 1] but received: {repr(kernel_weight)}'
-        self._kernel_weight = kernel_weight
+        assert 0 <= wrap_weight, f'wrap_weight must be in the range [0, inf) but received: {repr(wrap_weight)}'
+        assert 0 <= aug_weight, f'aug_weight must be in the range [0, inf) but received: {repr(aug_weight)}'
+        self._wrap_weight = wrap_weight
+        self._aug_weight = aug_weight
+        # disable gradients
+        for param in self.parameters():
+            param.requires_grad = False
 
     def activate(self, x_partial: torch.Tensor):
         return self._recon_loss_handler.activate(x_partial)
 
     def compute_unreduced_loss(self, x_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
-        aug_loss = self._recon_loss_handler.compute_unreduced_loss(self._kernel(x_recon), self._kernel(x_targ))
-        loss = self._recon_loss_handler.compute_unreduced_loss(x_recon, x_targ)
-        return (1. - self._kernel_weight) * loss + self._kernel_weight * aug_loss
+        wrap_loss = self._recon_loss_handler.compute_unreduced_loss(x_recon, x_targ)
+        aug_loss = self._recon_loss_handler.compute_unreduced_loss(self._kernel(x_recon), self._kernel(x_targ))
+        return (self._wrap_weight * wrap_loss) + (self._aug_weight * aug_loss)
 
     def compute_unreduced_loss_from_partial(self, x_partial_recon: torch.Tensor, x_targ: torch.Tensor) -> torch.Tensor:
         return self.compute_unreduced_loss(self.activate(x_partial_recon), x_targ)
diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py
index 22a32d47..366bd4e9 100644
--- a/disent/frameworks/vae/_unsupervised__vae.py
+++ b/disent/frameworks/vae/_unsupervised__vae.py
@@ -158,6 +158,10 @@ def do_training_step(self, batch, batch_idx):
             'recon_loss': recon_loss,
             'reg_loss': reg_loss,
             'aug_loss': aug_loss,
+            # ratios
+            'ratio_reg': (reg_loss / loss) if (loss != 0) else 0,
+            'ratio_rec': (recon_loss / loss) if (loss != 0) else 0,
+            'ratio_aug': (aug_loss / loss) if (loss != 0) else 0,
         }
 
     # --------------------------------------------------------------------- #
diff --git a/disent/util/lightning/callbacks/_callbacks_vae.py b/disent/util/lightning/callbacks/_callbacks_vae.py
index f033a1c6..e846b86d 100644
--- a/disent/util/lightning/callbacks/_callbacks_vae.py
+++ b/disent/util/lightning/callbacks/_callbacks_vae.py
@@ -122,25 +122,25 @@ def _to_dmat(
     return dmat
 
 
-_AE_DIST_NAMES = ('x', 'z_l1', 'x_recon')
-_VAE_DIST_NAMES = ('x', 'z_l1', 'kl', 'x_recon')
+_AE_DIST_NAMES = ('x', 'z', 'x_recon')
+_VAE_DIST_NAMES = ('x', 'z', 'kl', 'x_recon')
 
 
 @torch.no_grad()
-def _get_dists_ae(ae: Ae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor):
+def _get_dists_ae(ae: Ae, x_a: torch.Tensor, x_b: torch.Tensor):
     # feed forware
     z_a, z_b = ae.encode(x_a), ae.encode(x_b)
     r_a, r_b = ae.decode(z_a), ae.decode(z_b)
     # distances
     return [
-        recon_loss.compute_pairwise_loss(x_a, x_b),
-        torch.norm(z_a - z_b, p=1, dim=-1),  # l1 dist
-        recon_loss.compute_pairwise_loss(r_a, r_b),
+        ae.recon_handler.compute_pairwise_loss(x_a, x_b),
+        torch.norm(z_a - z_b, p=2, dim=-1),  # l2 dist
+        ae.recon_handler.compute_pairwise_loss(r_a, r_b),
     ]
 
 
 @torch.no_grad()
-def _get_dists_vae(vae: Vae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_b: torch.Tensor):
+def _get_dists_vae(vae: Vae, x_a: torch.Tensor, x_b: torch.Tensor):
     from torch.distributions import kl_divergence
     # feed forward
     (z_post_a, z_prior_a), (z_post_b, z_prior_b) = vae.encode_dists(x_a), vae.encode_dists(x_b)
@@ -150,19 +150,19 @@ def _get_dists_vae(vae: Vae, recon_loss: ReconLossHandler, x_a: torch.Tensor, x_
     kl_ab = 0.5 * kl_divergence(z_post_a, z_post_b) + 0.5 * kl_divergence(z_post_b, z_post_a)
     # distances
     return [
-        recon_loss.compute_pairwise_loss(x_a, x_b),
-        torch.norm(z_a - z_b, p=1, dim=-1),  # l1 dist
-        recon_loss._pairwise_reduce(kl_ab),
-        recon_loss.compute_pairwise_loss(r_a, r_b),
+        vae.recon_handler.compute_pairwise_loss(x_a, x_b),
+        torch.norm(z_a - z_b, p=2, dim=-1),  # l2 dist
+        vae.recon_handler._pairwise_reduce(kl_ab),
+        vae.recon_handler.compute_pairwise_loss(r_a, r_b),
     ]
 
 
-def _get_dists_fn(model, recon_loss: ReconLossHandler) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
+def _get_dists_fn(model: Ae) -> Tuple[Optional[Tuple[str, ...]], Optional[Callable[[object, object], Sequence[Sequence[float]]]]]:
     # get aggregate function
     if isinstance(model, Vae):
-        dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model, recon_loss)
+        dists_names, dists_fn = _VAE_DIST_NAMES, wrapped_partial(_get_dists_vae, model)
     elif isinstance(model, Ae):
-        dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model, recon_loss)
+        dists_names, dists_fn = _AE_DIST_NAMES, wrapped_partial(_get_dists_ae, model)
     else:
         dists_names, dists_fn = None, None
     return dists_names, dists_fn
@@ -303,7 +303,6 @@ def __init__(
         assert traversal_repeats > 0
         self._traversal_repeats = traversal_repeats
         self._seed = seed
-        self._recon_loss = make_reconstruction_loss('mse', 'mean')
         self._plt_block_size = plt_block_size
         self._plt_show = plt_show
         self._log_wandb = log_wandb
@@ -321,7 +320,7 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
             log.warning(f'cannot run {self.__class__.__name__} over non-ground-truth data, skipping!')
             return
         # get aggregate function
-        dists_names, dists_fn = _get_dists_fn(vae, self._recon_loss)
+        dists_names, dists_fn = _get_dists_fn(vae)
         if (dists_names is None) or (dists_fn is None):
             log.warning(f'cannot run {self.__class__.__name__}, unsupported model type: {type(vae)}, must be {Ae.__name__} or {Vae.__name__}')
             return
@@ -488,6 +487,19 @@ def do_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
             plt.show()
 
 
+def _normalized_numeric_metrics(items: dict):
+    results = {}
+    for k, v in items.items():
+        if isinstance(v, (float, int)):
+            results[k] = v
+        else:
+            try:
+                results[k] = float(v)
+            except:
+                log.warning(f'SKIPPED: metric with key: {repr(k)}, result has invalid type: {type(v)} with value: {repr(v)}')
+    return results
+
+
 class VaeMetricLoggingCallback(BaseCallbackPeriodic):
 
     def __init__(
@@ -521,10 +533,11 @@ def _compute_metrics_and_log(self, trainer: pl.Trainer, pl_module: pl.LightningM
 
                 scores = metric(dataset, lambda x: vae.encode(x.to(vae.device)))
                 metric_results = ' '.join(f'{k}{c.GRY}={c.lMGT}{v:.3f}{c.RST}' for k, v in scores.items())
                 log.info(f'| {metric.__name__:<{pad}} - time{c.GRY}={c.lYLW}{timer.pretty:<9}{c.RST} - {metric_results}')
+            # log to trainer
             prefix = 'final_metric' if is_final else 'epoch_metric'
             prefixed_scores = {f'{prefix}/{k}': v for k, v in scores.items()}
-            log_metrics(trainer.logger, prefixed_scores)
+            log_metrics(trainer.logger, _normalized_numeric_metrics(prefixed_scores))
 
             # log summary for WANDB
             # this is kinda hacky... the above should work for parallel coordinate plots
diff --git a/disent/util/visualize/plot.py b/disent/util/visualize/plot.py
index f4bafb81..826594c2 100644
--- a/disent/util/visualize/plot.py
+++ b/disent/util/visualize/plot.py
@@ -135,7 +135,10 @@ def plt_subplots(
     assert isinstance(ncols, int)
     # check titles
     if titles is not None:
-        titles = np.array(titles).reshape([nrows, ncols])
+        titles = np.array(titles)
+        if titles.ndim == 1:
+            titles = np.array([titles] + ([[None]*ncols] * (nrows-1)))
+        assert titles.ndim == 2
     # get labels
     if (row_labels is None) or isinstance(row_labels, str):
         row_labels = [row_labels] * nrows
@@ -161,7 +164,8 @@ def plt_subplots(
             ax.set_ylabel(row_labels[y], fontsize=label_size)
             # set title
             if titles is not None:
-                ax.set_title(titles[y][x], fontsize=titles_size)
+                if titles[y][x] is not None:
+                    ax.set_title(titles[y][x], fontsize=titles_size)
     # set title
     fig.suptitle(title, fontsize=title_size)
     # done!
diff --git a/experiment/config/run_action/skip.yaml b/experiment/config/run_action/skip.yaml
new file mode 100644
index 00000000..80cf8dcd
--- /dev/null
+++ b/experiment/config/run_action/skip.yaml
@@ -0,0 +1,2 @@
+# @package _global_
+action: skip
diff --git a/experiment/config/schedule/beta_cyclic_fast.yaml b/experiment/config/schedule/beta_cyclic_fast.yaml
index 5b64fc94..ef8ad3bb 100644
--- a/experiment/config/schedule/beta_cyclic_fast.yaml
+++ b/experiment/config/schedule/beta_cyclic_fast.yaml
@@ -1,4 +1,4 @@
-name: beta_cyclic
+name: beta_cyclic_fast
 
 schedule_items:
   beta:
diff --git a/experiment/config/schedule/beta_cyclic_slow.yaml b/experiment/config/schedule/beta_cyclic_slow.yaml
index de6f0333..46a5ddc7 100644
--- a/experiment/config/schedule/beta_cyclic_slow.yaml
+++ b/experiment/config/schedule/beta_cyclic_slow.yaml
@@ -1,4 +1,4 @@
-name: beta_cyclic
+name: beta_cyclic_slow
 
 schedule_items:
   beta:
diff --git a/experiment/config/schedule/beta_delay.yaml b/experiment/config/schedule/beta_delay.yaml
new file mode 100644
index 00000000..a223f278
--- /dev/null
+++ b/experiment/config/schedule/beta_delay.yaml
@@ -0,0 +1,10 @@
+name: beta_increase
+
+schedule_items:
+  beta:
+    _target_: disent.schedule.Single
+    start_step: 3600
+    end_step: 7200
+    r_start: 0.001
+    r_end: 1.0
+    mode: 'linear'
diff --git a/experiment/config/schedule/beta_delay_long.yaml b/experiment/config/schedule/beta_delay_long.yaml
new file mode 100644
index 00000000..353e75dc
--- /dev/null
+++ b/experiment/config/schedule/beta_delay_long.yaml
@@ -0,0 +1,10 @@
+name: beta_increase
+
+schedule_items:
+  beta:
+    _target_: disent.schedule.Single
+    start_step: 7200
+    end_step: 14400
+    r_start: 0.001
+    r_end: 1.0
+    mode: 'linear'
diff --git a/experiment/run.py b/experiment/run.py
index 760ac412..bfe57545 100644
--- a/experiment/run.py
+++ b/experiment/run.py
@@ -64,23 +64,24 @@
 # ========================================================================= #
 
 
-def hydra_check_cuda(cfg):
-    cuda = cfg.dsettings.trainer.cuda
-    # set cuda
-    if cuda in {'try_cuda', None}:
-        cfg.dsettings.trainer.cuda = torch.cuda.is_available()
-        if not cuda:
+def hydra_get_gpus(cfg) -> int:
+    use_cuda = cfg.dsettings.trainer.cuda
+    # check cuda values
+    if use_cuda in {'try_cuda', None}:
+        use_cuda = torch.cuda.is_available()
+        if not use_cuda:
             log.warning('CUDA was requested, but not found on this system... CUDA has been disabled!')
+    elif use_cuda:
+        if not torch.cuda.is_available():
+            log.error('`dsettings.trainer.cuda=True` but CUDA is not available on this machine!')
+            raise RuntimeError('CUDA not available!')
     else:
         if not torch.cuda.is_available():
-            if cuda:
-                log.error('trainer.cuda=True but CUDA is not available on this machine!')
-                raise RuntimeError('CUDA not available!')
-            else:
-                log.warning('CUDA is not available on this machine!')
+            log.info('CUDA is not available on this machine!')
         else:
-            if not cuda:
-                log.warning('CUDA is available but is not being used!')
+            log.warning('CUDA is available but is not being used!')
+    # get number of gpus to use
+    return (1 if use_cuda else 0)
 
 
 def hydra_check_data_paths(cfg):
@@ -273,7 +274,7 @@ def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cf
 
 
 # ========================================================================= #
-# ACTIONS #
+# ACTIONS                                                                   #
 # ========================================================================= #
 
 
@@ -301,9 +302,6 @@ def action_train(cfg: DictConfig):
     time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S')
     log.info(f'Starting run at time: {time_string}')
 
-    # print initial config
-    log.info(f'Initial Config For Action: {cfg.action}\n\nCONFIG:{make_box_str(OmegaConf.to_yaml(cfg), char_v=":", char_h=".")}')
-
     # -~-~-~-~-~-~-~-~-~-~-~-~- #
 
     # cleanup from old runs:
@@ -331,7 +329,7 @@ def action_train(cfg: DictConfig):
     log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}")
 
     # check CUDA setting
-    hydra_check_cuda(cfg)
+    gpus = hydra_get_gpus(cfg)
 
     # check data preparation
     hydra_check_data_paths(cfg)
@@ -354,7 +352,7 @@ def action_train(cfg: DictConfig):
     trainer = set_debug_trainer(pl.Trainer(
         logger=logger,
         callbacks=callbacks,
-        gpus=1 if cfg.dsettings.trainer.cuda else 0,
+        gpus=gpus,
         # we do this here too so we don't run the final
         # metrics, even through we check for it manually.
         terminate_on_nan=True,
@@ -390,11 +388,22 @@ def action_train(cfg: DictConfig):
     # initialising the training process we cannot capture it!
     trainer.fit(framework, datamodule=datamodule)
 
+    # -~-~-~-~-~-~-~-~-~-~-~-~- #
+
+    # cleanup this run
+    try:
+        wandb.finish()
+    except:
+        pass
+
+    # -~-~-~-~-~-~-~-~-~-~-~-~- #
+
 
 # available actions
 ACTIONS = {
     'prepare_data': action_prepare_data,
     'train': action_train,
+    'skip': lambda *args, **kwargs: None,
 }
 
 
diff --git a/setup.py b/setup.py
index 3505105e..a57f1318 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,7 @@
 
     author="Nathan Juraj Michlo",
     author_email="NathanJMichlo@gmail.com",
 
-    version="0.3.1",
+    version="0.3.2",
     python_requires=">=3.8",  # we make use of standard library features only in 3.8
     packages=setuptools.find_packages(),
diff --git a/tests/test_experiment.py b/tests/test_experiment.py
index 696a3ced..0ec7593f 100644
--- a/tests/test_experiment.py
+++ b/tests/test_experiment.py
@@ -38,6 +38,7 @@
 
 
 @pytest.mark.parametrize('args', [
+    ['run_action=skip'],
     ['run_action=prepare_data'],
     ['run_action=train'],
 ])