From c6dd7556d98523b232f112076d8ac2736bc33a46 Mon Sep 17 00:00:00 2001
From: DerThorsten
Date: Mon, 9 Sep 2019 14:46:52 +0200
Subject: [PATCH 1/3] added better logger

---
 inferno/utils/singleton.py          | 38 ++++++++++++++++++++
 inferno/utils/tensorboard_logger.py | 54 +++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+)
 create mode 100644 inferno/utils/singleton.py
 create mode 100644 inferno/utils/tensorboard_logger.py

diff --git a/inferno/utils/singleton.py b/inferno/utils/singleton.py
new file mode 100644
index 00000000..9d804591
--- /dev/null
+++ b/inferno/utils/singleton.py
@@ -0,0 +1,38 @@
+# https://stackoverflow.com/questions/31875/is-there-a-simple-elegant-way-to-define-singletons
+class Singleton:
+    """
+    A non-thread-safe helper class to ease implementing singletons.
+    This should be used as a decorator -- not a metaclass -- to the
+    class that should be a singleton.
+
+    The decorated class can define one `__init__` function that
+    takes only the `self` argument. Also, the decorated class cannot be
+    inherited from. Other than that, there are no restrictions that apply
+    to the decorated class.
+
+    To get the singleton instance, use the `instance` method. Trying
+    to use `__call__` will result in a `TypeError` being raised.
+
+    """
+
+    def __init__(self, decorated):
+        self._decorated = decorated
+
+    def instance(self):
+        """
+        Returns the singleton instance. Upon its first call, it creates a
+        new instance of the decorated class and calls its `__init__` method.
+        On all subsequent calls, the already created instance is returned.
+
+        """
+        try:
+            return self._instance
+        except AttributeError:
+            self._instance = self._decorated()
+            return self._instance
+
+    def __call__(self):
+        raise TypeError('Singletons must be accessed through `instance()`.')
+
+    def __instancecheck__(self, inst):
+        return isinstance(inst, self._decorated)
\ No newline at end of file
diff --git a/inferno/utils/tensorboard_logger.py b/inferno/utils/tensorboard_logger.py
new file mode 100644
index 00000000..b474f793
--- /dev/null
+++ b/inferno/utils/tensorboard_logger.py
@@ -0,0 +1,54 @@
+from .. utils.singleton import Singleton
+from tensorboardX import SummaryWriter
+from functools import partialmethod
+
+@Singleton
+class TensorboardLogger(object):
+
+    def __init__(self):
+        print("construct")
+        self.summary_writer = None
+
+
+
+    def setup(self, *args, **kwargs ):
+        if self.summary_writer is not None:
+            raise RuntimeError("set_up can only be called once")
+        self.summary_writer = SummaryWriter(*args, **kwargs)
+
+    def add_audio(self, *args, **kwargs ):
+        return self.summary_writer.add_audio(*args, **kwargs)
+    def add_custom_scalars(self, *args, **kwargs ):
+        return self.summary_writer.add_custom_scalars(*args, **kwargs)
+    def add_custom_scalars_marginchart(self, *args, **kwargs ):
+        return self.summary_writer.add_custom_scalars_marginchart(*args, **kwargs)
+    def add_custom_scalars_multilinechart(self, *args, **kwargs ):
+        return self.summary_writer.add_custom_scalars_multilinechart(*args, **kwargs)
+    def add_figure(self, *args, **kwargs ):
+        return self.summary_writer.add_figure(*args, **kwargs)
+    def add_graph(self, *args, **kwargs ):
+        return self.summary_writer.add_graph(*args, **kwargs)
+    def add_histogram(self, *args, **kwargs ):
+        return self.summary_writer.add_histogram(*args, **kwargs)
+    def add_histogram_raw(self, *args, **kwargs ):
+        return self.summary_writer.add_histogram_raw(*args, **kwargs)
+    def add_hparams(self, *args, **kwargs ):
+        return self.summary_writer.add_hparams(*args, **kwargs)
+    def add_image(self, *args, **kwargs ):
+        return self.summary_writer.add_image(*args, **kwargs)
+    def add_mesh(self, *args, **kwargs ):
+        return self.summary_writer.add_mesh(*args, **kwargs)
+    def add_pr_curve(self, *args, **kwargs ):
+        return self.summary_writer.add_pr_curve(*args, **kwargs)
+    def add_pr_curve_raw(self, *args, **kwargs ):
+        return self.summary_writer.add_pr_curve_raw(*args, **kwargs)
+    def add_scalar(self, *args, **kwargs ):
+        return self.summary_writer.add_scalar(*args, **kwargs)
+    def add_scalars(self, *args, **kwargs ):
+        return self.summary_writer.add_scalars(*args, **kwargs)
+    def add_text(self, *args, **kwargs ):
+        return self.summary_writer.add_text(*args, **kwargs)
+    def add_video(self, *args, **kwargs ):
+        return self.summary_writer.add_video(*args, **kwargs)
+    def add_embedding(self, *args, **kwargs ):
+        return self.summary_writer.add_embedding(*args, **kwargs)

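The `Singleton` helper above is the backbone of the logger: the decorated class is
constructed lazily on the first `instance()` call, every later call returns that same
object, and direct construction is blocked. A minimal sketch of the intended usage --
the `Config` class is a made-up stand-in, only `Singleton` comes from this patch:

    from inferno.utils.singleton import Singleton


    @Singleton
    class Config:
        # hypothetical example class; any no-argument __init__ works
        def __init__(self):
            self.values = {}


    a = Config.instance()          # first call constructs the object
    b = Config.instance()          # later calls return the very same object
    assert a is b
    assert isinstance(a, Config)   # works because of __instancecheck__

    try:
        Config()                   # direct calls are rejected on purpose
    except TypeError as e:
        print(e)

As the docstring says, this is not thread-safe: two threads racing on the first
`instance()` call can end up constructing two objects.
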
From 626ed73533496ea23004dd9515d26c6307e3e9b5 Mon Sep 17 00:00:00 2001
From: DerThorsten
Date: Tue, 10 Sep 2019 12:58:39 +0200
Subject: [PATCH 2/3] improved tensorboard summary writer

---
 inferno/trainers/basic.py                      |  78 ++++++++
 .../trainers/tensorboard_summary_writer.py     | 181 ++++++++++++++++++
 inferno/utils/tensorboard_logger.py            |  54 ------
 3 files changed, 259 insertions(+), 54 deletions(-)
 create mode 100644 inferno/trainers/tensorboard_summary_writer.py
 delete mode 100644 inferno/utils/tensorboard_logger.py

diff --git a/inferno/trainers/basic.py b/inferno/trainers/basic.py
index c90c8ef5..f2ce3da6 100755
--- a/inferno/trainers/basic.py
+++ b/inferno/trainers/basic.py
@@ -24,6 +24,7 @@ from ..extensions import optimizers
 from ..extensions import criteria
 from .callbacks import CallbackEngine
+from .tensorboard_summary_writer import TensorboardSummaryWriter
 from .callbacks import Console
 from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError
@@ -783,6 +784,83 @@ def set_target_batch_dim(self, value):
         self.target_batch_dim = value
         return self
 
+    @staticmethod
+    def tensorboard_summary_writer():
+        return TensorboardSummaryWriter.instance()
+
+    def setup_tensorboard_summary_writer(self,
+                                         add_audio_every=None,
+                                         add_custom_scalars_every=None,
+                                         add_custom_scalars_marginchart_every=None,
+                                         add_custom_scalars_multilinechart_every=None,
+                                         add_figure_every=None,
+                                         add_graph_every=None,
+                                         add_histogram_every=None,
+                                         add_histogram_raw_every=None,
+                                         add_hparams_every=None,
+                                         add_image_every=None,
+                                         add_mesh_every=None,
+                                         add_pr_curve_every=None,
+                                         add_pr_curve_raw_every=None,
+                                         add_scalar_every=None,
+                                         add_scalars_every=None,
+                                         add_text_every=None,
+                                         add_video_every=None,
+                                         add_embedding_every=None,
+                                         log_directory=None,
+                                         **kwargs):
+
+        # fall back to the trainer's log directory if none is given
+        if log_directory is None:
+            log_directory = self._log_directory
+
+        # tensorboardX's SummaryWriter takes the directory as 'logdir'
+        if log_directory is not None:
+            kwargs['logdir'] = log_directory
+
+        instance = Trainer.tensorboard_summary_writer()
+        instance.setup(trainer=self, **kwargs)
+
+        if add_audio_every is not None:
+            instance.add_audio_every(add_audio_every)
+        if add_custom_scalars_every is not None:
+            instance.add_custom_scalars_every(add_custom_scalars_every)
+        if add_custom_scalars_marginchart_every is not None:
+            instance.add_custom_scalars_marginchart_every(add_custom_scalars_marginchart_every)
+        if add_custom_scalars_multilinechart_every is not None:
+            instance.add_custom_scalars_multilinechart_every(add_custom_scalars_multilinechart_every)
+        if add_figure_every is not None:
+            instance.add_figure_every(add_figure_every)
+        if add_graph_every is not None:
+            instance.add_graph_every(add_graph_every)
+        if add_histogram_every is not None:
+            instance.add_histogram_every(add_histogram_every)
+        if add_histogram_raw_every is not None:
+            instance.add_histogram_raw_every(add_histogram_raw_every)
+        if add_hparams_every is not None:
+            instance.add_hparams_every(add_hparams_every)
+        if add_image_every is not None:
+            instance.add_image_every(add_image_every)
+        if add_mesh_every is not None:
+            instance.add_mesh_every(add_mesh_every)
+        if add_pr_curve_every is not None:
+            instance.add_pr_curve_every(add_pr_curve_every)
+        if add_pr_curve_raw_every is not None:
+            instance.add_pr_curve_raw_every(add_pr_curve_raw_every)
+        if add_scalar_every is not None:
+            instance.add_scalar_every(add_scalar_every)
+        if add_scalars_every is not None:
+            instance.add_scalars_every(add_scalars_every)
+        if add_text_every is not None:
+            instance.add_text_every(add_text_every)
+        if add_video_every is not None:
+            instance.add_video_every(add_video_every)
+        if add_embedding_every is not None:
+            instance.add_embedding_every(add_embedding_every)
+
     def build_logger(self, logger=None, log_directory=None, **kwargs):
         """
         Build the logger.
diff --git a/inferno/trainers/tensorboard_summary_writer.py b/inferno/trainers/tensorboard_summary_writer.py
new file mode 100644
index 00000000..893ded6c
--- /dev/null
+++ b/inferno/trainers/tensorboard_summary_writer.py
@@ -0,0 +1,181 @@
+from ..utils.singleton import Singleton
+from ..utils import train_utils as tu
+from tensorboardX import SummaryWriter
+
+
+@Singleton
+class TensorboardSummaryWriter(object):
+
+    def __init__(self):
+        self.summary_writer = None
+        self.trainer = None
+
+        # per-method logging frequencies; None means "log on every call"
+        self._add_audio = None
+        self._add_custom_scalars = None
+        self._add_custom_scalars_marginchart = None
+        self._add_custom_scalars_multilinechart = None
+        self._add_figure = None
+        self._add_graph = None
+        self._add_histogram = None
+        self._add_histogram_raw = None
+        self._add_hparams = None
+        self._add_image = None
+        self._add_mesh = None
+        self._add_pr_curve = None
+        self._add_pr_curve_raw = None
+        self._add_scalar = None
+        self._add_scalars = None
+        self._add_text = None
+        self._add_video = None
+        self._add_embedding = None
+
+    def setup(self, trainer, *args, **kwargs):
+        if self.summary_writer is not None:
+            raise RuntimeError("setup() can only be called once")
+        self.summary_writer = SummaryWriter(*args, **kwargs)
+        self.trainer = trainer
+
+    def _match(self, frequency):
+        # log if no frequency is configured for a method, or if the
+        # configured frequency matches the trainer's current counters
+        if frequency is None:
+            return True
+        else:
+            return frequency.match(
+                epoch_count=self.trainer.epoch_count,
+                iteration_count=self.trainer.iteration_count,
+                persistent=True
+            )
+
+    def add_audio(self, *args, **kwargs):
+        if self._match(self._add_audio):
+            return self.summary_writer.add_audio(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_custom_scalars(self, *args, **kwargs):
+        # the custom scalars layout is global; tensorboardX takes no global_step here
+        if self._match(self._add_custom_scalars):
+            return self.summary_writer.add_custom_scalars(*args, **kwargs)
+    def add_custom_scalars_marginchart(self, *args, **kwargs):
+        if self._match(self._add_custom_scalars_marginchart):
+            return self.summary_writer.add_custom_scalars_marginchart(*args, **kwargs)
+    def add_custom_scalars_multilinechart(self, *args, **kwargs):
+        if self._match(self._add_custom_scalars_multilinechart):
+            return self.summary_writer.add_custom_scalars_multilinechart(*args, **kwargs)
+    def add_figure(self, *args, **kwargs):
+        if self._match(self._add_figure):
+            return self.summary_writer.add_figure(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_graph(self, *args, **kwargs):
+        # graphs are not per-step; tensorboardX takes no global_step here
+        if self._match(self._add_graph):
+            return self.summary_writer.add_graph(*args, **kwargs)
+    def add_histogram(self, *args, **kwargs):
+        if self._match(self._add_histogram):
+            return self.summary_writer.add_histogram(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_histogram_raw(self, *args, **kwargs):
+        if self._match(self._add_histogram_raw):
+            return self.summary_writer.add_histogram_raw(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_hparams(self, *args, **kwargs):
+        # hparams are logged once per run; tensorboardX takes no global_step here
+        if self._match(self._add_hparams):
+            return self.summary_writer.add_hparams(*args, **kwargs)
+    def add_image(self, *args, **kwargs):
+        if self._match(self._add_image):
+            return self.summary_writer.add_image(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_mesh(self, *args, **kwargs):
+        if self._match(self._add_mesh):
+            return self.summary_writer.add_mesh(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_pr_curve(self, *args, **kwargs):
+        if self._match(self._add_pr_curve):
+            return self.summary_writer.add_pr_curve(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_pr_curve_raw(self, *args, **kwargs):
+        if self._match(self._add_pr_curve_raw):
+            return self.summary_writer.add_pr_curve_raw(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_scalar(self, *args, **kwargs):
+        if self._match(self._add_scalar):
+            return self.summary_writer.add_scalar(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_scalars(self, *args, **kwargs):
+        if self._match(self._add_scalars):
+            return self.summary_writer.add_scalars(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_text(self, *args, **kwargs):
+        if self._match(self._add_text):
+            return self.summary_writer.add_text(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_video(self, *args, **kwargs):
+        if self._match(self._add_video):
+            return self.summary_writer.add_video(*args, **kwargs, global_step=self.trainer.iteration_count)
+    def add_embedding(self, *args, **kwargs):
+        if self._match(self._add_embedding):
+            return self.summary_writer.add_embedding(*args, **kwargs, global_step=self.trainer.iteration_count)
+
+    def add_audio_every(self, frequency):
+        self._add_audio = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_custom_scalars_every(self, frequency):
+        self._add_custom_scalars = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_custom_scalars_marginchart_every(self, frequency):
+        self._add_custom_scalars_marginchart = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_custom_scalars_multilinechart_every(self, frequency):
+        self._add_custom_scalars_multilinechart = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_figure_every(self, frequency):
+        self._add_figure = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_graph_every(self, frequency):
+        self._add_graph = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_histogram_every(self, frequency):
+        self._add_histogram = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_histogram_raw_every(self, frequency):
+        self._add_histogram_raw = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_hparams_every(self, frequency):
+        self._add_hparams = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_image_every(self, frequency):
+        self._add_image = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_mesh_every(self, frequency):
+        self._add_mesh = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_pr_curve_every(self, frequency):
+        self._add_pr_curve = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_pr_curve_raw_every(self, frequency):
+        self._add_pr_curve_raw = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_scalar_every(self, frequency):
+        self._add_scalar = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_scalars_every(self, frequency):
+        self._add_scalars = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_text_every(self, frequency):
+        self._add_text = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_video_every(self, frequency):
+        self._add_video = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
+    def add_embedding_every(self, frequency):
+        self._add_embedding = tu.Frequency.build_from(frequency, priority='iterations')
+        return self
diff --git a/inferno/utils/tensorboard_logger.py b/inferno/utils/tensorboard_logger.py
deleted file mode 100644
index b474f793..00000000
--- a/inferno/utils/tensorboard_logger.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from .. utils.singleton import Singleton
-from tensorboardX import SummaryWriter
-from functools import partialmethod
-
-@Singleton
-class TensorboardLogger(object):
-
-    def __init__(self):
-        print("construct")
-        self.summary_writer = None
-
-
-
-    def setup(self, *args, **kwargs ):
-        if self.summary_writer is not None:
-            raise RuntimeError("set_up can only be called once")
-        self.summary_writer = SummaryWriter(*args, **kwargs)
-
-    def add_audio(self, *args, **kwargs ):
-        return self.summary_writer.add_audio(*args, **kwargs)
-    def add_custom_scalars(self, *args, **kwargs ):
-        return self.summary_writer.add_custom_scalars(*args, **kwargs)
-    def add_custom_scalars_marginchart(self, *args, **kwargs ):
-        return self.summary_writer.add_custom_scalars_marginchart(*args, **kwargs)
-    def add_custom_scalars_multilinechart(self, *args, **kwargs ):
-        return self.summary_writer.add_custom_scalars_multilinechart(*args, **kwargs)
-    def add_figure(self, *args, **kwargs ):
-        return self.summary_writer.add_figure(*args, **kwargs)
-    def add_graph(self, *args, **kwargs ):
-        return self.summary_writer.add_graph(*args, **kwargs)
-    def add_histogram(self, *args, **kwargs ):
-        return self.summary_writer.add_histogram(*args, **kwargs)
-    def add_histogram_raw(self, *args, **kwargs ):
-        return self.summary_writer.add_histogram_raw(*args, **kwargs)
-    def add_hparams(self, *args, **kwargs ):
-        return self.summary_writer.add_hparams(*args, **kwargs)
-    def add_image(self, *args, **kwargs ):
-        return self.summary_writer.add_image(*args, **kwargs)
-    def add_mesh(self, *args, **kwargs ):
-        return self.summary_writer.add_mesh(*args, **kwargs)
-    def add_pr_curve(self, *args, **kwargs ):
-        return self.summary_writer.add_pr_curve(*args, **kwargs)
-    def add_pr_curve_raw(self, *args, **kwargs ):
-        return self.summary_writer.add_pr_curve_raw(*args, **kwargs)
-    def add_scalar(self, *args, **kwargs ):
-        return self.summary_writer.add_scalar(*args, **kwargs)
-    def add_scalars(self, *args, **kwargs ):
-        return self.summary_writer.add_scalars(*args, **kwargs)
-    def add_text(self, *args, **kwargs ):
-        return self.summary_writer.add_text(*args, **kwargs)
-    def add_video(self, *args, **kwargs ):
-        return self.summary_writer.add_video(*args, **kwargs)
-    def add_embedding(self, *args, **kwargs ):
-        return self.summary_writer.add_embedding(*args, **kwargs)

From f9f09ce64648917cebc3834b1b6f0bafe20afded Mon Sep 17 00:00:00 2001
From: DerThorsten
Date: Mon, 14 Oct 2019 13:11:05 +0200
Subject: [PATCH 3/3] added new example

---
 examples/README.txt                   |   6 -
 examples/plot_cheap_unet.py           | 241 -----------------------
 examples/plot_train_side_loss_unet.py | 224 ----------------------
 examples/plot_unet_tutorial.py        | 263 --------------------------
 examples/regularized_mnist.py         | 124 ------------
 examples/tensorboard_logger.py        | 128 +++++++++++++
 examples/trainer.py                   |  75 --------
 7 files changed, 128 insertions(+), 933 deletions(-)
 delete mode 100755 examples/README.txt
 delete mode 100644 examples/plot_cheap_unet.py
 delete mode 100644 examples/plot_train_side_loss_unet.py
 delete mode 100644 examples/plot_unet_tutorial.py
 delete mode 100644 examples/regularized_mnist.py
 create mode 100644 examples/tensorboard_logger.py
 delete mode 100644 examples/trainer.py

diff
--git a/examples/README.txt b/examples/README.txt deleted file mode 100755 index 968323cb..00000000 --- a/examples/README.txt +++ /dev/null @@ -1,6 +0,0 @@ - -.. _examples-index: - -Gallery of Examples -=================== - diff --git a/examples/plot_cheap_unet.py b/examples/plot_cheap_unet.py deleted file mode 100644 index bfef8565..00000000 --- a/examples/plot_cheap_unet.py +++ /dev/null @@ -1,241 +0,0 @@ -""" -UNet Tutorial -================================ -A unet example which can be run without a gpu -""" - -############################################################################## -# Preface -# -------------- -# We start with some unspectacular multi purpose imports needed for this example -import matplotlib.pyplot as plt -import torch -from torch import nn -import numpy - - -############################################################################## - -# determine whether we have a gpu -# and should use cuda -USE_CUDA = torch.cuda.is_available() - - -############################################################################## -# Dataset -# -------------- -# For simplicity we will use a toy dataset where we need to perform -# a binary segmentation task. -from inferno.io.box.binary_blobs import get_binary_blob_loaders - -# convert labels from long to float as needed by -# binary cross entropy loss -def label_transform(x): - return torch.from_numpy(x).float() -#label_transform = lambda x : torch.from_numpy(x).float() - -train_loader, test_loader, validate_loader = get_binary_blob_loaders( - size=8, # how many images per {train,test,validate} - train_batch_size=2, - length=256, # <= size of the images - gaussian_noise_sigma=1.4, # <= how noise are the images - train_label_transform = label_transform, - validate_label_transform = label_transform -) - -image_channels = 1 # <-- number of channels of the image -pred_channels = 1 # <-- number of channels needed for the prediction - -if False: - ############################################################################## - # Visualize Dataset - # ~~~~~~~~~~~~~~~~~~~~~~ - fig = plt.figure() - - for i,(image, target) in enumerate(train_loader): - ax = fig.add_subplot(1, 2, 1) - ax.imshow(image[0,0,...]) - ax.set_title('raw data') - ax = fig.add_subplot(1, 2, 2) - ax.imshow(target[0,...]) - ax.set_title('ground truth') - break - fig.tight_layout() - plt.show() - - - - -############################################################################## -# Training -# ---------------------------- -# To train the unet, we use the infernos Trainer class of inferno. -# Since we train many models later on in this example we encapsulate -# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for -# an example dedicated to the trainer itself). 
-from inferno.trainers import Trainer -from inferno.utils.python_utils import ensure_dir - -def train_model(model, loaders, **kwargs): - - trainer = Trainer(model) - trainer.build_criterion('BCEWithLogitsLoss') - trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001)) - #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs')) - #trainer.save_every((kwargs.get('save_every', 10), 'epochs')) - #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor'))) - trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20)) - - # bind the loaders - trainer.bind_loader('train', loaders[0]) - trainer.bind_loader('validate', loaders[1]) - - if USE_CUDA: - trainer.cuda() - - # do the training - trainer.fit() - - return trainer - - - - -############################################################################## -# Prediction -# ---------------------------- -# The trainer contains the trained model and we can do predictions. -# We use :code:`unwrap` to convert the results to numpy arrays. -# Since we want to do many prediction we encapsulate the -# the prediction in a function -from inferno.utils.torch_utils import unwrap - -def predict(trainer, test_loader, save_dir=None): - - - trainer.eval_mode() - for image, target in test_loader: - - # transfer image to gpu - image = image.cuda() if USE_CUDA else image - - # get batch size from image - batch_size = image.size()[0] - - for b in range(batch_size): - prediction = trainer.apply_model(image) - prediction = torch.nn.functional.sigmoid(prediction) - - image = unwrap(image, as_numpy=True, to_cpu=True) - prediction = unwrap(prediction, as_numpy=True, to_cpu=True) - target = unwrap(target, as_numpy=True, to_cpu=True) - - fig = plt.figure() - - ax = fig.add_subplot(2, 2, 1) - ax.imshow(image[b,0,...]) - ax.set_title('raw data') - - ax = fig.add_subplot(2, 2, 2) - ax.imshow(target[b,...]) - ax.set_title('ground truth') - - ax = fig.add_subplot(2, 2, 4) - ax.imshow(prediction[b,...]) - ax.set_title('prediction') - - fig.tight_layout() - plt.show() - - - -############################################################################## -# Custom UNet -# ---------------------------- -# Often one needs to have a UNet with custom layers. -# Here we show how to implement such a customized UNet. -# To this end we derive from :code:`UNetBase`. 
-# For the sake of this example we will create -# a Unet which uses depthwise convolutions and might be trained on a CPU -from inferno.extensions.models import UNetBase -from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D,ConvActivation - - -class CheapConv(nn.Module): - def __init__(self, in_channels, out_channels, activated): - super(CheapConv, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - if activated: - self.convs = torch.nn.Sequential( - ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2), - ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) - ) - else: - self.convs = torch.nn.Sequential( - ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2), - Conv2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) - ) - def forward(self, x): - assert x.shape[1] == self.in_channels,"input has wrong number of channels" - x = self.convs(x) - assert x.shape[1] == self.out_channels,"output has wrong number of channels" - return x - - -class CheapConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, activated): - super(CheapConvBlock, self).__init__() - self.activated = activated - self.in_channels = in_channels - self.out_channels = out_channels - if(in_channels != out_channels): - self.start = ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1)) - else: - self.start = None - self.conv_a = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=True) - self.conv_b = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=False) - self.activation = torch.nn.ReLU() - def forward(self, x): - x_input = x - if self.start is not None: - x_input = self.start(x_input) - - x = self.conv_a(x_input) - x = self.conv_b(x) - - x = x + x_input - - if self.activated: - x = self.activation(x) - return x - -class MySimple2DCpUnet(UNetBase): - def __init__(self, in_channels, out_channels, depth=3, residual=False, **kwargs): - super(MySimple2DCpUnet, self).__init__(in_channels=in_channels, out_channels=out_channels, - dim=2, depth=depth, **kwargs) - - def conv_op_factory(self, in_channels, out_channels, part, index): - - # last? 
- last = part == 'up' and index==0 - return CheapConvBlock(in_channels=in_channels, out_channels=out_channels, activated=not last),False - - - -from inferno.extensions.layers import RemoveSingletonDimension -model_b = torch.nn.Sequential( - CheapConv(in_channels=image_channels, out_channels=4, activated=True), - MySimple2DCpUnet(in_channels=4, out_channels=pred_channels) , - RemoveSingletonDimension(dim=1) -) - - -################################################### -# do the training (with the same functions as before) -trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001) - -################################################### -# do the training (with the same functions as before)1 -predict(trainer=trainer, test_loader=test_loader) - diff --git a/examples/plot_train_side_loss_unet.py b/examples/plot_train_side_loss_unet.py deleted file mode 100644 index ba5394fd..00000000 --- a/examples/plot_train_side_loss_unet.py +++ /dev/null @@ -1,224 +0,0 @@ -""" -Train Side Loss UNet Example -================================ - -In this example a UNet with side supervision -and auxiliary loss implemented - -""" - -############################################################################## -# Imports needed for this example -import torch -import torch.nn as nn -from inferno.io.box.binary_blobs import get_binary_blob_loaders -from inferno.trainers.basic import Trainer - -from inferno.extensions.layers.convolutional import Conv2D -from inferno.extensions.models.res_unet import _ResBlock as ResBlock -from inferno.extensions.models import ResBlockUNet -from inferno.utils.torch_utils import unwrap -from inferno.utils.python_utils import ensure_dir -import pylab - - -############################################################################## -# To create a UNet with side loss we create a new nn.Module class -# which has a ResBlockUNet as member. -# The ResBlockUNet is configured such that the results of the -# bottom convolution and all the results of the up-stream -# convolutions are returned as (side)-output. -# a 1x1 convolutions is used to give the side outputs -# the right number of out_channels and UpSampling is -# used to resize all side-outputs to the full resolution -# of the input. These side `side-predictions` are -# returned by our MySideLossUNet. -# Furthermore, all `side-predictions` are concatenated -# and feed trough another two residual blocks to make -# the final prediction. 
-class MySideLossUNet(nn.Module): - def __init__(self, in_channels, out_channels, depth=3): - super(MySideLossUNet, self).__init__() - - self.depth = depth - self.unet = ResBlockUNet(in_channels=in_channels, out_channels=in_channels*2, - dim=2, unet_kwargs=dict(depth=depth), - side_out_parts=['bottom', 'up']) - - # number of out channels - self.n_channels_per_output = self.unet.n_channels_per_output - - # 1x1 conv to give the side outs of the unet - # the right number of channels - # and a Upsampling to give the right shape - upscale_factor = 2**self.depth - conv_and_scale = [] - for n_channels in self.n_channels_per_output: - - # conv blocks - conv = Conv2D(in_channels=n_channels, out_channels=out_channels, kernel_size=1) - if upscale_factor > 1: - upsample = nn.Upsample(scale_factor=upscale_factor) - conv_and_scale.append(nn.Sequential(conv, upsample)) - else: - conv_and_scale.append(conv) - - upscale_factor //= 2 - - self.conv_and_scale = nn.ModuleList(conv_and_scale) - - # combined number of channels after concat - # concat side output predictions with main output of unet - self.n_channels_combined = (self.depth + 1)* out_channels + in_channels*2 - - self.final_block = nn.Sequential( - ResBlock(dim=2,in_channels=self.n_channels_combined, out_channels=self.n_channels_combined), - ResBlock(in_channels=self.n_channels_combined, out_channels=out_channels, - dim=2, activated=False), - ) - - def forward(self, input): - outs = self.unet(input) - assert len(outs) == len(self.n_channels_per_output) - - # convert the unet output into the right number of - preds = [None] * len(outs) - for i,out in enumerate(outs): - preds[i] = self.conv_and_scale[i](out) - - # this is the side output - preds = tuple(preds) - - # concat side output predictions with main output of unet - combined = torch.cat(preds + (outs[-1],), 1) - - final_res = self.final_block(combined) - - # return everything - return preds + (final_res,) - -############################################################################## -# We use a custom loss functions which applied CrossEntropyLoss -# to all side outputs. -# The side outputs are weighted in a quadratic fashion and added up -# into a single value -class MySideLoss(nn.Module): - """Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion. 
- """ - - def __init__(self): - super(MySideLoss, self).__init__() - self.criterion = nn.CrossEntropyLoss(reduce=True) - - w = 1.0 - l = None - - def forward(self, predictions, target): - w = 1.0 - l = None - for p in predictions: - ll = self.criterion(p, target)*w - if l is None: - l = ll - else: - l += ll - w *= 2 - return l - - - -############################################################################## -# Training boilerplate (see :ref:`sphx_glr_auto_examples_trainer.py`) -LOG_DIRECTORY = ensure_dir('log') -SAVE_DIRECTORY = ensure_dir('save') -DATASET_DIRECTORY = ensure_dir('dataset') - - -USE_CUDA = torch.cuda.is_available() - -# Build a residual unet where the last layer is not activated -sl_unet = MySideLossUNet(in_channels=5, out_channels=2) - -model = nn.Sequential( - ResBlock(dim=2, in_channels=1, out_channels=5), - sl_unet -) -train_loader, test_loader, validate_loader = get_binary_blob_loaders( - train_batch_size=3, - length=512, # <= size of the images - gaussian_noise_sigma=1.5 # <= how noise are the images -) - -# Build trainer -trainer = Trainer(model) -trainer.build_criterion(MySideLoss()) -trainer.build_optimizer('Adam') -trainer.validate_every((10, 'epochs')) -#trainer.save_every((10, 'epochs')) -#trainer.save_to_directory(SAVE_DIRECTORY) -trainer.set_max_num_epochs(40) - -# Bind loaders -trainer \ - .bind_loader('train', train_loader)\ - .bind_loader('validate', validate_loader) - -if USE_CUDA: - trainer.cuda() - -# Go! -trainer.fit() - - -############################################################################## -# Predict with the trained network -# and visualize the results - -# predict: -#trainer.load(best=True) -trainer.bind_loader('train', train_loader) -trainer.bind_loader('validate', validate_loader) -trainer.eval_mode() - -if USE_CUDA: - trainer.cuda() - -# look at an example -for img,target in test_loader: - if USE_CUDA: - img = img.cuda() - - # softmax on each of the prediction - preds = trainer.apply_model(img) - preds = [nn.functional.softmax(pred,dim=1) for pred in preds] - preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds] - img = unwrap(img, as_numpy=True, to_cpu=True) - target = unwrap(target, as_numpy=True, to_cpu=True) - - n_plots = len(preds) + 2 - batch_size = preds[0].shape[0] - - for b in range(batch_size): - - fig = pylab.figure() - - ax1 = fig.add_subplot(2,4,1) - ax1.set_title('image') - ax1.imshow(img[b,0,...]) - - ax2 = fig.add_subplot(2,4,2) - ax2.set_title('ground truth') - ax2.imshow(target[b,...]) - - for i,pred in enumerate(preds): - axn = fig.add_subplot(2,4, 3+i) - axn.imshow(pred[b,1,...]) - - if i + 1 < len(preds): - axn.set_title('side prediction %d'%i) - else: - axn.set_title('combined prediction') - - pylab.show() - - break diff --git a/examples/plot_unet_tutorial.py b/examples/plot_unet_tutorial.py deleted file mode 100644 index b30de73d..00000000 --- a/examples/plot_unet_tutorial.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -UNet Tutorial -================================ -A tentative tutorial on the usage -of the unet framework in inferno -""" - -############################################################################## -# Preface -# -------------- -# We start with some unspectacular multi purpose imports needed for this example -import matplotlib.pyplot as plt -import torch -import numpy - -############################################################################## - -# determine whether we have a gpu -# and should use cuda -USE_CUDA = torch.cuda.is_available() - - 
-############################################################################## -# Dataset -# -------------- -# For simplicity we will use a toy dataset where we need to perform -# a binary segmentation task. -from inferno.io.box.binary_blobs import get_binary_blob_loaders - -# convert labels from long to float as needed by -# binary cross entropy loss -def label_transform(x): - return torch.from_numpy(x).float() -#label_transform = lambda x : torch.from_numpy(x).float() - -train_loader, test_loader, validate_loader = get_binary_blob_loaders( - size=8, # how many images per {train,test,validate} - train_batch_size=2, - length=256, # <= size of the images - gaussian_noise_sigma=1.4, # <= how noise are the images - train_label_transform = label_transform, - validate_label_transform = label_transform -) - -image_channels = 1 # <-- number of channels of the image -pred_channels = 1 # <-- number of channels needed for the prediction - -############################################################################## -# Visualize Dataset -# ~~~~~~~~~~~~~~~~~~~~~~ -fig = plt.figure() - -for i,(image, target) in enumerate(train_loader): - ax = fig.add_subplot(1, 2, 1) - ax.imshow(image[0,0,...]) - ax.set_title('raw data') - ax = fig.add_subplot(1, 2, 2) - ax.imshow(target[0,...]) - ax.set_title('ground truth') - break -fig.tight_layout() -plt.show() - - -############################################################################## -# Simple UNet -# ---------------------------- -# We start with a very simple predefined -# res block UNet. By default, this UNet uses ReLUs (in conjunction with batchnorm) as nonlinearities -# With :code:`activated=False` we make sure that the last layer -# is not activated since we chain the UNet with a sigmoid -# activation function. -from inferno.extensions.models import ResBlockUNet -from inferno.extensions.layers import RemoveSingletonDimension - -model = torch.nn.Sequential( - ResBlockUNet(dim=2, in_channels=image_channels, out_channels=pred_channels, activated=False), - RemoveSingletonDimension(dim=1), - torch.nn.Sigmoid() -) - -############################################################################## -# while the model above will work in principal, it has some drawbacks. -# Within the UNet, the number of features is increased by a multiplicative -# factor while going down, the so-called gain. The default value for the gain is 2. -# Since we start with only a single channel we could either increase the gain, -# or use a some convolutions to increase the number of channels -# before the the UNet. -from inferno.extensions.layers import ConvReLU2D -model_a = torch.nn.Sequential( - ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3), - ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels, activated=False, - res_block_kwargs=dict(batchnorm=True,size=2)) , - RemoveSingletonDimension(dim=1) - # torch.nn.Sigmoid() -) - - - - - -############################################################################## -# Training -# ---------------------------- -# To train the unet, we use the infernos Trainer class of inferno. -# Since we train many models later on in this example we encapsulate -# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for -# an example dedicated to the trainer itself). 
-from inferno.trainers import Trainer -from inferno.utils.python_utils import ensure_dir - -def train_model(model, loaders, **kwargs): - - trainer = Trainer(model) - trainer.build_criterion('BCEWithLogitsLoss') - trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001)) - #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs')) - #trainer.save_every((kwargs.get('save_every', 10), 'epochs')) - #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor'))) - trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 200)) - - # bind the loaders - trainer.bind_loader('train', loaders[0]) - trainer.bind_loader('validate', loaders[1]) - - if USE_CUDA: - trainer.cuda() - - # do the training - trainer.fit() - - return trainer - - -trainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a', lr=0.01) - - - -############################################################################## -# Prediction -# ---------------------------- -# The trainer contains the trained model and we can do predictions. -# We use :code:`unwrap` to convert the results to numpy arrays. -# Since we want to do many prediction we encapsulate the -# the prediction in a function -from inferno.utils.torch_utils import unwrap - -def predict(trainer, test_loader, save_dir=None): - - - trainer.eval_mode() - for image, target in test_loader: - - # transfer image to gpu - image = image.cuda() if USE_CUDA else image - - # get batch size from image - batch_size = image.size()[0] - - for b in range(batch_size): - prediction = trainer.apply_model(image) - prediction = torch.nn.functional.sigmoid(prediction) - - image = unwrap(image, as_numpy=True, to_cpu=True) - prediction = unwrap(prediction, as_numpy=True, to_cpu=True) - target = unwrap(target, as_numpy=True, to_cpu=True) - - fig = plt.figure() - - ax = fig.add_subplot(2, 2, 1) - ax.imshow(image[b,0,...]) - ax.set_title('raw data') - - ax = fig.add_subplot(2, 2, 2) - ax.imshow(target[b,...]) - ax.set_title('ground truth') - - ax = fig.add_subplot(2, 2, 4) - ax.imshow(prediction[b,...]) - ax.set_title('prediction') - - fig.tight_layout() - plt.show() - -################################################### -# do the prediction -predict(trainer=trainer, test_loader=test_loader) - - - - -############################################################################## -# Custom UNet -# ---------------------------- -# Often one needs to have a UNet with custom layers. -# Here we show how to implement such a customized UNet. -# To this end we derive from :code:`UNetBase`. 
-# For the sake of this example we will create -# a rather exotic UNet which uses different types -# of convolutions/non-linearities in the different branches -# of the unet -from inferno.extensions.models import UNetBase -from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D -from inferno.extensions.layers.sampling import Upsample - -class MySimple2DUnet(UNetBase): - def __init__(self, in_channels, out_channels, depth=3, **kwargs): - super(MySimple2DUnet, self).__init__(in_channels=in_channels, out_channels=out_channels, - dim=2, depth=depth, **kwargs) - - def conv_op_factory(self, in_channels, out_channels, part, index): - - if part == 'down': - return torch.nn.Sequential( - ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3), - ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3) - ), False - elif part == 'bottom': - return torch.nn.Sequential( - ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3), - ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3), - ), False - elif part == 'up': - # are we in the very last block? - if index == 0: - return torch.nn.Sequential( - ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3), - Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3) - ), False - else: - return torch.nn.Sequential( - ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3), - ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3) - ), False - else: - raise RuntimeError("something is wrong") - - - - - # this function CAN be implemented, if not, MaxPooling is used by default - def downsample_op_factory(self, index): - return torch.nn.MaxPool2d(kernel_size=2, stride=2) - - # this function CAN be implemented, if not, Upsampling is used by default - def upsample_op_factory(self, index): - return Upsample(mode='bilinear', align_corners=False,scale_factor=2) - -model_b = torch.nn.Sequential( - ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3), - MySimple2DUnet(in_channels=5, out_channels=pred_channels) , - RemoveSingletonDimension(dim=1) -) - - -################################################### -# do the training (with the same functions as before) -trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001) - -################################################### -# do the training (with the same functions as before) -predict(trainer=trainer, test_loader=test_loader) - diff --git a/examples/regularized_mnist.py b/examples/regularized_mnist.py deleted file mode 100644 index e2f871c6..00000000 --- a/examples/regularized_mnist.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -Regularized MNIST Example -================================ - -This example demonstrates adding and logging arbitrary regularization losses, in this case, -L2 activity regularization and L1 weight regularization. 
- -- Add a `_losses` dictionary to any module containing loss names and values -- Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses -- Call `Trainer.observe_training_and_validation_states` to log the losses as well -""" - -import argparse -import sys - -import torch -import torch.nn as nn -from torchvision import datasets, transforms - -from inferno.extensions.layers.reshape import Flatten -from inferno.trainers.basic import Trainer -from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger - - -class RegularizedLinear(nn.Linear): - def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs): - super(RegularizedLinear, self).__init__(*args, **kwargs) - self.ar_weight = ar_weight - self.l1_weight = l1_weight - self._losses = {} - - def forward(self, input): - output = super(RegularizedLinear, self).forward(input) - self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight - self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight - return output - - -def model_fn(): - return nn.Sequential( - Flatten(), - RegularizedLinear(in_features=784, out_features=256), - nn.LeakyReLU(), - RegularizedLinear(in_features=256, out_features=128), - nn.LeakyReLU(), - RegularizedLinear(in_features=128, out_features=10) - ) - - -def mnist_data_loaders(args): - kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} - train_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, **kwargs) - test_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.test_batch_size, shuffle=True, **kwargs) - return train_loader, test_loader - - -def train_model(args): - model = model_fn() - train_loader, validate_loader = mnist_data_loaders(args) - - # Build trainer - trainer = Trainer(model) \ - .build_criterion('RegularizedCrossEntropyLoss') \ - .build_metric('CategoricalError') \ - .build_optimizer('Adam') \ - .validate_every((1, 'epochs')) \ - .save_every((1, 'epochs')) \ - .save_to_directory(args.save_directory) \ - .set_max_num_epochs(args.epochs) \ - .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'), - log_images_every='never'), - log_directory=args.save_directory) - - # Record regularization losses - trainer.logger.observe_training_and_validation_states([ - 'main_loss', - 'total_regularization_loss', - 'activity_regularization', - 'l1_weight_regularization' - ]) - - # Bind loaders - trainer \ - .bind_loader('train', train_loader) \ - .bind_loader('validate', validate_loader) - - if args.cuda: - trainer.cuda() - - # Go! 
-    trainer.fit()
-
-
-def main(argv):
-    # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--save-directory', type=str, default='output/mnist/v1',
-                        help='output directory')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=20, metavar='N',
-                        help='number of epochs to train (default: 20)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    args = parser.parse_args(argv)
-    args.cuda = not args.no_cuda and torch.cuda.is_available()
-    train_model(args)
-
-
-if __name__ == '__main__':
-    main(sys.argv[1:])
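The example below drives the new writer with `add_scalars_every=(1, 'iteration')` and
`add_embedding_every=(1, 'epoch')`. Such tuples are parsed by
`inferno.utils.train_utils.Frequency`, the same mechanism behind `validate_every` and
`save_every`. A rough sketch of the gating that `TensorboardSummaryWriter._match`
performs internally (the counter values here are made up):

    from inferno.utils import train_utils as tu

    # build a frequency just like add_scalars_every(...) does internally
    freq = tu.Frequency.build_from((2, 'epochs'), priority='iterations')

    # the writer checks this before every add_* call; with (2, 'epochs')
    # it should match whenever the epoch counter hits a multiple of two
    print(freq.match(epoch_count=4, iteration_count=1234, persistent=True))
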
diff --git a/examples/tensorboard_logger.py b/examples/tensorboard_logger.py
new file mode 100644
index 00000000..d35f54e7
--- /dev/null
+++ b/examples/tensorboard_logger.py
@@ -0,0 +1,128 @@
+# general imports
+import multiprocessing
+import os
+import numpy
+
+# torch imports
+import torch
+from torch import nn
+import torch.utils.data as data
+from torchvision import datasets
+
+# inferno imports
+from inferno.trainers.basic import Trainer
+
+
+# access the logger singleton from any module
+tb_logger = Trainer.tensorboard_summary_writer()
+
+
+class FlatMNist(data.Dataset):
+
+    def __init__(self):
+        super().__init__()
+        self.mnist = datasets.MNIST(root='.', download=True)
+
+    def __len__(self):
+        return len(self.mnist)
+
+    def __getitem__(self, i):
+        img, l = self.mnist[i]
+        one_hot = torch.zeros(10)
+        one_hot[l] = 1
+        img = numpy.array(img).astype('float32') / 255.0
+        flat_mnist = img.reshape([784])
+        # two inputs (image, one-hot label) and two targets (image, label)
+        return flat_mnist, one_hot, flat_mnist, l
+
+
+class MyLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.rec_loss = nn.BCELoss(reduction='sum')
+
+    def forward(self, output, targets):
+        rec, mu, logvar = output
+        y_rec, y_labels = targets
+
+        as_img = y_rec.view([-1, 1, 28, 28])
+        as_img = as_img.repeat([1, 3, 1, 1])
+
+        # only logged as often as add_embedding_every allows
+        tb_logger.add_embedding(mu, metadata=y_labels, label_img=as_img)
+
+        rec_loss = self.rec_loss(rec, y_rec)
+        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        scaled_kld = 0.001 * kld
+        total = rec_loss + scaled_kld
+        tb_logger.add_scalars('loss', {
+            'rec_loss': rec_loss,
+            'kld': kld,
+            'scaled_kld': scaled_kld,
+            'total': total
+        })
+
+        return total
+
+
+class VAE(nn.Module):
+    def __init__(self):
+        super(VAE, self).__init__()
+
+        self.fc1 = nn.Linear(784 + 10, 400)
+        self.fc21 = nn.Linear(400, 20)
+        self.fc22 = nn.Linear(400, 20)
+        self.fc3 = nn.Linear(20, 400)
+        self.fc4 = nn.Linear(400, 784)
+        self.relu = nn.ReLU()
+
+    def encode(self, x, y):
+        # condition the encoder on the one-hot label
+        x = torch.cat([x, y], dim=1)
+        h1 = self.relu(self.fc1(x))
+        return self.fc21(h1), self.fc22(h1)
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def decode(self, z, y):
+        # z = torch.cat([z, y], dim=1)  # label-conditioning of the decoder is disabled
+        h3 = self.relu(self.fc3(z))
+        return torch.sigmoid(self.fc4(h3))
+
+    def forward(self, x, y):
+        mu, logvar = self.encode(x.view(-1, 784), y)
+        z = self.reparameterize(mu, logvar)
+        return self.decode(z, y), mu, logvar
+
+
+# fill this in:
+out_dir = 'somedir'
+if not os.path.exists(out_dir):
+    os.makedirs(out_dir)
+
+ds = FlatMNist()
+train_loader = data.DataLoader(ds, batch_size=3000,
+                               num_workers=multiprocessing.cpu_count())
+model = VAE()
+trainer = Trainer(model)
+trainer.setup_tensorboard_summary_writer(
+    log_directory=out_dir,
+    add_scalars_every=(1, 'iteration'),
+    add_embedding_every=(1, 'epoch')
+)
+if torch.cuda.is_available():
+    trainer.cuda()
+trainer.build_criterion(MyLoss())
+trainer.build_optimizer('Adam', lr=0.01)
+trainer.save_every((1, 'epochs'))
+trainer.save_to_directory(out_dir)
+trainer.set_max_num_epochs(100000)
+
+# bind the training loader (two inputs, two targets per batch)
+trainer.bind_loader('train', train_loader, num_inputs=2, num_targets=2)
+trainer.fit()
diff --git a/examples/trainer.py b/examples/trainer.py
deleted file mode 100644
index fa481626..00000000
--- a/examples/trainer.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""
-Trainer Example
-================================
-
-This example should illustrate how to use the trainer class.
-
-"""
-
-import torch.nn as nn
-from inferno.io.box.cifar import get_cifar10_loaders
-from inferno.trainers.basic import Trainer
-from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
-from inferno.extensions.layers import ConvELU2D
-from inferno.extensions.layers import Flatten
-from inferno.utils.python_utils import ensure_dir
-
-from inferno.extensions.layers import SELU
-
-##################################################
-# change directories to your needs
-LOG_DIRECTORY = ensure_dir('log')
-SAVE_DIRECTORY = ensure_dir('save')
-DATASET_DIRECTORY = ensure_dir('dataset')
-
-##################################################
-# shall models be downloaded
-DOWNLOAD_CIFAR = True
-USE_CUDA = True
-
-##################################################
-# Build torch model
-model = nn.Sequential(
-    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
-    nn.MaxPool2d(kernel_size=2, stride=2),
-    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
-    nn.MaxPool2d(kernel_size=2, stride=2),
-    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
-    nn.MaxPool2d(kernel_size=2, stride=2),
-    Flatten(),
-    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
-    nn.Softmax()
-)
-
-##################################################
-# data loaders
-train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
-                                                    download=DOWNLOAD_CIFAR)
-
-##################################################
-# Build trainer
-trainer = Trainer(model)
-trainer.build_criterion('CrossEntropyLoss')
-trainer.build_metric('CategoricalError')
-trainer.build_optimizer('Adam')
-trainer.validate_every((2, 'epochs'))
-trainer.save_every((5, 'epochs'))
-trainer.save_to_directory(SAVE_DIRECTORY)
-trainer.set_max_num_epochs(10)
-trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
-                                       log_images_every='never'),
-                     log_directory=LOG_DIRECTORY)
-
-##################################################
-# Bind loaders
-trainer.bind_loader('train', train_loader)
-trainer.bind_loader('validate', validate_loader)
-
-##################################################
-# activate cuda
-if USE_CUDA:
-    trainer.cuda()
-
-##################################################
-# fit
-trainer.fit()
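Because the writer is a process-wide singleton, any module can log into the same event
file once `setup_tensorboard_summary_writer` has run on the trainer -- there is no need
to pass the writer around. A small sketch of that pattern; the `debug_stats` helper and
its tag are invented for illustration:

    # some_module.py -- log from anywhere in the code base
    from inferno.trainers.basic import Trainer

    tb = Trainer.tensorboard_summary_writer()   # same instance as everywhere else

    def debug_stats(tensor):
        # global_step is injected by the wrapper from the trainer's iteration
        # counter; if add_histogram_every was never configured, the frequency
        # is None and the call logs every time
        tb.add_histogram('debug/activations', tensor)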