diff --git a/invtf/data.py b/invtf/data.py new file mode 100644 index 0000000..1219f31 --- /dev/null +++ b/invtf/data.py @@ -0,0 +1,20 @@ +import tensorflow as tf +import os + +def load_image_dataset(folder, epochs=1,new_image_size=(64, 64),batch_size=32,shuffle=True,): + def _parse_function(filename): + image_string = tf.io.read_file(filename) + image_decoded = tf.image.decode_jpeg(image_string) + image_resized = tf.image.resize(image_decoded, new_size) + return image_resized + + files = ['{}/{}'.format(folder,f) for f in os.listdir( + folder) if os.path.isfile(os.path.join(folder, f))] + dataset = tf.data.Dataset.from_tensor_slices(tf.constant(files)) + if shuffle == True: + dataset = dataset.shuffle(buffer_size=100) + dataset = dataset.repeat(count=epochs) + dataset = dataset.map(map_func=_parse_function,num_parallel_calls=4) + dataset = dataset.prefetch(buffer_size=batch_size) + dataset = dataset.batch(batch_size=batch_size) + return iter(dataset),(len(files)//batch_size) diff --git a/invtf/generator_const_backprop.py b/invtf/generator_const_backprop.py new file mode 100644 index 0000000..1871065 --- /dev/null +++ b/invtf/generator_const_backprop.py @@ -0,0 +1,418 @@ +""" + Contains the generator class with constant memory depth backprop + +""" + +import os + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ['TF_CPP_MIN_VLOG_LEVEL']='3' + +from tqdm import tqdm +import tensorflow as tf +import invtf.grow_memory +import tensorflow.keras as keras +import numpy as np +import invtf.latent +import matplotlib.pyplot as plt +from invtf.dequantize import * +from invtf.layers import * +from invtf import latent + + + +""" + + TODO: + + - Support specifying different latent distributions, see e.g. NICE. + + - The fit currently uses a dummy 'y=X'. It is not used, but removing it causes an error with 'total_loss'. + Removing might speed up. + + Comments: + We are miss-using the Sequential thing as it is normally just a linear stack of layers. + If we use the multi-scale architecture this is not the case, as it has multiple outputs. + +""" +class Generator(keras.Sequential): + + def __init__(self, latent=latent.Normal(28**2)): + self.latent = latent + + super(Generator, self).__init__() + + + + # Sequential is normally only for linear stack, however, the multiple outputs in multi-scale architecture + # is fairly straight forward, so we change Sequential slightly to allow multiple outputs just for the + # case of the MultiScale layer. Refactor this to make a new variant MutliSqualeSequential which + # Generator inherents from. + + def add(self, layer): + from tensorflow.python.keras.utils import tf_utils + from tensorflow.python.keras.engine import training_utils + from tensorflow.python.util import nest + from tensorflow.python.keras.utils import layer_utils + from tensorflow.python.util import tf_inspect + + + # If we are passed a Keras tensor created by keras.Input(), we can extract + # the input layer from its keras history and use that without any loss of + # generality. + if hasattr(layer, '_keras_history'): + origin_layer = layer._keras_history[0] + if isinstance(origin_layer, keras.layers.InputLayer): + layer = origin_layer + + if not isinstance(layer, keras.layers.Layer): + raise TypeError('The added layer must be ' + 'an instance of class Layer. ' + 'Found: ' + str(layer)) + + tf_utils.assert_no_legacy_layers([layer]) + + self.built = False + set_inputs = False + if not self._layers: + if isinstance(layer, keras.layers.InputLayer): + # Corner case where the user passes an InputLayer layer via `add`. + assert len(nest.flatten(layer._inbound_nodes[-1].output_tensors)) == 1 + set_inputs = True + else: + batch_shape, dtype = training_utils.get_input_shape_and_dtype(layer) + if batch_shape: + # Instantiate an input layer. + x = keras.layers.Input( + batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') + # This will build the current layer + # and create the node connecting the current layer + # to the input layer we just created. + layer(x) + set_inputs = True + + if set_inputs: + # If an input layer (placeholder) is available. + if len(nest.flatten(layer._inbound_nodes[-1].output_tensors)) != 1: + raise ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + self.outputs = [ + nest.flatten(layer._inbound_nodes[-1].output_tensors)[0] + ] + self.inputs = layer_utils.get_source_inputs(self.outputs[0]) + + elif self.outputs: + # If the model is being built continuously on top of an input layer: + # refresh its output. + output_tensor = layer(self.outputs[0]) + if len(nest.flatten(output_tensor)) != 1 and not isinstance(layer, MultiScale): + raise TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + self.outputs = [output_tensor] + + if self.outputs: + # True if set_inputs or self._is_graph_network or if adding a layer + # to an already built deferred seq model. + self.built = True + + if set_inputs or self._is_graph_network: + self._init_graph_network(self.inputs, self.outputs, name=self.name) + else: + self._layers.append(layer) + if self._layers: + self._track_layers(self._layers) + + self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call) + + + + # def predict(self, X, dequantize=True): + + # Zs = [] + + # for layer in self.layers: + + # # allow deactivating dequenatize + # # refactor to just look into name of layer and skip if it has dequantize in name or something like that. + # if not dequantize and isinstance(layer, UniformDequantize): continue + # if not dequantize and isinstance(layer, VariationalDequantize): continue + + # # if isinstance(layer, MultiScale): + # # X, Z = layer.call(X) + # # Zs.append(Z) + # # continue + + # X = layer.call(X) + + # # TODO: make sure this does not break case without multiscale architecture. + # # append Zs to X;; do by vectorize and then concat. + + # return X, Zs + + def predict_inv(self, X, Z=None): + n = X.shape[0] + + for layer in self.layers[::-1]: + + if isinstance(layer, MultiScale): + X = layer.call_inv(X, Z.pop()) + + else: + X = layer.call_inv(X) + + return np.array(X, dtype=np.int32) # makes it easier on matplotlib. + + def log_det(self): + logdet = 0. + + for layer in self.layers: + if isinstance(layer, tf.keras.layers.InputLayer): continue + logdet += layer.log_det() + return logdet + + + def loss(self, y_pred): + # computes negative log likelihood in bits per dimension. + # We are overriding the fit function, so we do not need to conform to tf.keras's pointless args. + return self.loss_log_det( y_pred) + self.loss_log_latent_density( y_pred) + + def loss_log_det(self, y_pred): + # divide by /d to get per dimension and divide by log(2) to get from log base E to log base 2. + d = tf.cast(tf.reduce_prod(y_pred.shape[1:]), tf.float32) + norm = d * np.log(2.) + log_det = self.log_det() / norm + + return - log_det + + + def loss_log_latent_density(self, y_pred): + # divide by /d to get per dimension and divide by log(2) to get from log base E to log base 2. + batch_size = tf.cast(tf.shape(y_pred)[0], tf.float32) + d = tf.cast(tf.reduce_prod(y_pred.shape[1:]), tf.float32) + norm = d * np.log(2.) + normal = self.latent.log_density(y_pred) / (norm * batch_size) + + return - normal + + def compile(self, **kwargs): + # overrides what'ever loss the user specifieds; change to complain with exception if they specify it with + #TODO remove this function, since we are overriding fit, we don't need this + kwargs['loss'] = self.loss + + def lg_det(y_true, y_pred): return self.loss_log_det(y_true, y_pred) + def lg_latent(y_true, y_pred): return self.loss_log_latent_density(y_true, y_pred) + def lg_perfect(y_true, y_pred): return self.loss_log_latent_density(y_true, self.latent.sample(n=1000)) + + kwargs['metrics'] = [lg_det, lg_latent, lg_perfect] + + super(Generator, self).compile(**kwargs) + + def train_on_batch(self,X,optimizer=None): + ''' + Computes gradients efficiently and updates weights + Returns - Loss on the batch + TODO - see keras.engine.train_generator.py , they use a similar function. + ''' + x = self.call(X) #I think putting this in context records all operations onto the tape, thereby destroying purpose of checkpointing... + last_layer = self.layers[-1] + #Computing gradients of loss function wrt the last acticvation + with tf.GradientTape() as tape: + tape.watch(x) + loss = self.loss(x) #May have to change + grads_combined = tape.gradient(loss,[x]) + dy = grads_combined[0] + y = x + #Computing gradients for each layer + for layer in self.layers[::-1]: + x = layer.call_inv(y) + dy,grads = layer.compute_gradients(x,dy,layer.log_det) #TODO implement scaling here... + optimizer.apply_gradients(zip(gradientsrads,layer.trainable_variables)) + y = x + return loss + + def fit(self, X, batch_size=32,epochs=1,verbose=1,validation_split=0.0, + validation_data=None, + shuffle=True, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_freq=1, + optimizer=tf.optimizers.Adam(),**kwargs): + ''' + Fits the model on dataset `X (not a generator) + Note - for very big datasets, the function will give OOM, + consider using a generator + Args- + X - Data to be fitted. Maybe one of the following- + tf.EagerTensor + np.ndarray + batch_size - Number of elements in each minibatch + verbose - Logging level + validation_split - Amount of data to be used for validation in each epoch + For tensors or arrays, data is extracted from initial part of dataset. + shuffle - Should training data be shuffled before mini-batches are extracted + steps_per_epoch - Number of training steps per epoch. Used mainly for generators. + validation_steps - Number of validation steps per epoch. Used mainly for generators. + + ''' + # TODO add all callbacks from tf.keras.Model.fit + # TODO return a history object instead of array of losses + all_losses = [] + if validation_split > 0 and validation_data is None: + validation_data = X[:int(len(X)*validation_split)] + X = X[int(len(X)*validation_split):] + + epoch_gen = range(initial_epoch,epochs) + if verbose == 1: + epoch_gen = tqdm(epoch_gen) + batch_size = min(batch_size,X.shape[0]) #Sanity check + num_batches = X.shape[0] // batch_size + if steps_per_epoch == None: + steps_per_epoch = num_batches + val_count = 0 + + for j in epoch_gen: + if shuffle == True: + X = np.random.permutation(X) #Works for np.ndarray and tf.EagerTensor, however, turns everything to numpy + #Minibatch gradient descent + range_gen = range(steps_per_epoch) + if verbose == 2: + range_gen = tqdm(range_gen) + for i in range_gen: + losses = [] + loss = self.train_on_batch(X[i*batch_size:(i+1)*(batch_size)],optimizer) + losses.append(loss.numpy()) + loss = np.mean(losses) + all_losses+=losses + to_print = 'Epoch: {}/{}, training_loss: {}'.format(j,epochs,loss) + if validation_data is not None and val_count%validation_freq==0: + val_loss = self.loss(validation_data) + to_print += ', val_loss: {}'.format(val_loss.numpy()) #TODO return val_loss somehow + if verbose == 2: + print(to_print) + val_count+=1 + return all_losses + + def fit_generator(self, generator,steps_per_epoch=None,initial_epoch=0, + epochs=1, + verbose=1,validation_data=None, + validation_freq=1, + shuffle=True, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + optimizer=tf.optimizers.Adam(), + **kwargs): + ''' + Fits model on the data generator `generator + IMPORTANT - Please consider using invtf.data.load_image_dataset() + Args - + generator - tf.data.Dataset, tf.keras.utils.Sequence or python generator + validation_data - same type as generator + steps_per_epoch - int, number of batches per epoch. + ''' + #TODO add callbacks and history + all_losses = [] + if isinstance(generator,tf.keras.utils.Sequence): + enqueuer = tf.keras.utils.OrderedEnqueuer(generator,use_multiprocessing,shuffle) + if steps_per_epoch == None: + steps_per_epoch = len(generator) #TODO test this, see if it works for both Sequence and Dataset + enqueuer.start(workers=workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + elif isinstance(generator,tf.data.Dataset): + output_generator = iter(generator) + else: + enqueuer = tf.keras.utils.GeneratorEnqueuer(generator,use_multiprocessing) # Can't shuffle here! + enqueuer.start(workers=workers, max_queue_size=max_queue_size) + output_generator = enqueuer.get() + if validation_data is not None: #Assumption that validation data and generator are same type + if isinstance(generator,tf.keras.utils.Sequence): + val_enqueuer = tf.keras.utils.OrderedEnqueuer(validation_data,use_multiprocessing,shuffle) + val_enqueuer.start(workers=workers, max_queue_size=max_queue_size) + val_generator = val_enqueuer.get() + elif isinstance(generator,tf.data.Dataset): + val_generator = iter(val_generator) + else: + val_enqueuer = tf.keras.utils.GeneratorEnqueuer(validation_data,use_multiprocessing) # Can't shuffle here! + val_enqueuer.start(workers=workers, max_queue_size=max_queue_size) + val_generator = val_enqueuer.get() + + if steps_per_epoch == None: + raise ValueError("steps_per_epoch cannot be None with provided generator") + epoch_gen = range(initial_epoch,epochs) + if verbose == 1: + epoch_gen = tqdm(epoch_gen) + for j in epoch_gen: + range_gen = range(steps_per_epoch) + if verbose == 2: + range_gen = tqdm(range_gen) + for i in range_gen: + losses = [] + loss = self.train_on_batch(next(output_generator),optimizer) + losses.append(loss.numpy()) + loss = np.mean(losses) + to_print = 'Epoch: {}/{}, training_loss: {}'.format(j,epochs,loss) + if validation_data is not None and val_count%validation_freq==0: + val_loss = self.loss(next(val_generator)) + to_print += ', val_loss: {}'.format(val_loss.numpy()) #TODO return val_loss somehow + if verbose == 2: + print(to_print) + all_losses+=losses + val_count+=1 + try: + if enqueuer is not None: + enqueuer.stop() + except: + pass + return all_losses + + def rec(self, X): + + X, Zs = self.predict(X, dequantize=False) # TODO: deactivate dequantize. + rec = self.predict_inv(X, Zs) + return rec + + def check_inv(self, X, precision=10**(-5)): + img_shape = X.shape[1:] + + rec = self.rec(X) + + if not np.allclose(X, rec, atol=precision): + fig, ax = plt.subplots(5, 3) + for i in range(5): + ax[i, 0].imshow(X[i].reshape(img_shape).astype(np.int32)) + ax[i, 0].set_title("Image") + ax[i, 1].imshow(rec[i].reshape(img_shape)) + ax[i, 1].set_title("Reconstruction") + ax[i, 2].imshow((X[i]-rec[i]).reshape(img_shape)) + ax[i, 2].set_title("Difference") + plt.show() + + + def sample(self, n=1000, fix_latent=True): + #Z = self.latent.sample(n=n, fix_latent=fix_latent) + + # Figure out how to handle shape of Z. If no multi-scale arch we want to do reshape below. + # If multi-scale arch we don't want to, predict_inv handles it. Figure out who has the responsibility. + + output_shape = self.layers[-1].output_shape[1:] + + X = self.latent.sample(shape=(n, ) + output_shape) + + for layer in self.layers[::-1]: + + if isinstance(layer, MultiScale): + Z = self.latent.sample(shape=X.shape) + X = layer.call_inv(X, Z) + else: + X = layer.call_inv(X) + + return np.array(X, dtype=np.int32) # makes it easier on matplotlib. + + return fakes + + diff --git a/invtf/layers.py b/invtf/layers_const_backprop.py similarity index 67% rename from invtf/layers.py rename to invtf/layers_const_backprop.py index 25a8c89..f4bfe45 100644 --- a/invtf/layers.py +++ b/invtf/layers_const_backprop.py @@ -5,192 +5,54 @@ from invtf.override import print_summary from invtf.coupling_strategy import * -class ReduceNumBits(keras.layers.Layer): - """ - Glow used 5 bit variant of CelebA. - Flow++ had 3 and 5 bit variants of ImageNet. - These lower bit variants allow better dimensionality reduction. - This layer should be the first within the model. - - This also means subsequent normalization needs to divide by less. - In this sense likelihood is incomparable between different number of bits. - - It seems to work, but it is a bit instable. - """ - def __init__(self, bits=5): # assumes input has 8 bits. - self.bits = 5 - super(ReduceNumBits, self).__init__() - - def call(self, X): - X = tf.dtypes.cast(X, dtype=np.float32) - return X // ( 2**(8-self.bits) ) - - def call_inv(self, Z): - # THIS PART IS NOT INVERTIBLE!! - return Z * (2**(8-self.bits)) - - def log_det(self): return 0. - - - -class ActNorm(keras.layers.Layer): - - """ - The exp parameter allows the scaling to be exp(s) \odot X. - This cancels out the log in the log_det computations. - """ - - def __init__(self, exp=False, **kwargs): - self.exp = exp - super(ActNorm, self).__init__(**kwargs) - - def build(self, input_shape): - - n, h, w, c = input_shape - self.h = h - self.w = w - - self.s = self.add_weight(shape=c, initializer='ones', name="affine_scale") - self.b = self.add_weight(shape=c, initializer='zero', name="affine_bias") - - super(ActNorm, self).build(input_shape) - self.built = True - - def call(self, X): return X * self.s + self.b - def call_inv(self, Z): return (Z - self.b) / self.s - - def log_det(self): return self.h * self.w * tf.reduce_sum(tf.math.log(tf.abs(self.s))) - - def compute_output_shape(self, input_shape): return input_shape - - - - -""" - The affine coupling layer is described in NICE, REALNVP and GLOW. - The description in Glow use a single network to output scale s and transform t, - it seems the description in REALNVP is a bit more general refering to s and t as - different functions. From this perspective Glow change the affine layer to have - weight sharing between s and t. - Specifying a single function is a lot simpler code-wise, we thus use that approach. - - - For now assumes the use of convolutions - -""" -class AffineCoupling(keras.layers.Layer): # Sequential): - - def add(self, layer): self.layers.append(layer) - - unique_id = 1 - - def __init__(self, part=0, strategy=SplitChannelsStrategy()): - super(AffineCoupling, self).__init__(name="aff_coupling_%i"%AffineCoupling.unique_id) - AffineCoupling.unique_id += 1 - self.part = part - self.strategy = strategy - self.layers = [] - self._is_graph_network = False - self.precomputed_log_det = 0. - - def _check_trainable_weights_consistency(self): return True - - def build(self, input_shape): - - # handle the issue with each network output something larger. - _, h, w, c = input_shape - - - h, w, c = self.strategy.coupling_shape(input_shape=(h,w,c)) - - self.layers[0].build(input_shape=(None, h, w, c)) - out_dim = self.layers[0].compute_output_shape(input_shape=(None, h, w, c)) - self.layers[0].output_shape_ = out_dim - - for layer in self.layers[1:]: - layer.build(input_shape=out_dim) - out_dim = layer.compute_output_shape(input_shape=out_dim) - layer.output_shape_ = out_dim - - - super(AffineCoupling, self).build(input_shape) - self.built = True - - def call_(self, X): - - in_shape = tf.shape(X) - n, h, w, c = X.shape - - for layer in self.layers: - X = layer.call(X) - - # TODO: Could have a part of network learned specifically for s,t to not ONLY have wegith sharing? - - # Using strategy from - # https://github.com/openai/glow/blob/eaff2177693a5d84a1cf8ae19e8e0441715b82f8/model.py#L376 - X = tf.reshape(X, (-1, h, w, c*2)) - s = X[:, :, :, ::2] # add a strategy pattern to decide how the output is split. - t = X[:, :, :, 1::2] - #s = tf.math.sigmoid(s) - - #s = X[:, :, w//2:, :] - #t = X[:, :, :w//2, :] - - s = tf.reshape(s, in_shape) - t = tf.reshape(t, in_shape) - - return s, t - - def call(self, X): - - x0, x1 = self.strategy.split(X) - - if self.part == 0: - s, t = self.call_(x1) - x0 = x0*s + t # glow changed order of this? i.e. translate then scale. - - if self.part == 1: - s, t = self.call_(x0) - x1 = x1*s + t - - self.precompute_log_det(s, X) - - X = self.strategy.combine(x0, x1) - return X - - def call_inv(self, Z): - z0, z1 = self.strategy.split(Z) - - if self.part == 0: - s, t = self.call_(z1) - z0 = (z0 - t)/s - if self.part == 1: - s, t = self.call_(z0) - z1 = (z1 - t)/s - - Z = self.strategy.combine(z0, z1) - return Z - - def precompute_log_det(self, s, X): - n = tf.dtypes.cast(tf.shape(X)[0], tf.float32) - self.precomputed_log_det = tf.reduce_sum(tf.math.log(tf.abs(s))) / n - - def log_det(self): return self.precomputed_log_det - - def compute_output_shape(self, input_shape): return input_shape - - def summary(self, line_length=None, positions=None, print_fn=None): - print_summary(self, line_length=line_length, positions=positions, print_fn=print_fn) # fixes stupid issue. - - """ Known issue with multi-scale architecture. The log-det computations normalizes wrt full dimension. """ - -class Linear(keras.layers.Layer): +#TODO Write unit tests +class LayerWithGrads(keras.layers.Layer): + ''' + This is a virtual class from which all layer classes need to inherit + It has the function `compute gradients` which is used for constant + memory backprop. + ''' + def __init__(self,**kwargs): + super(LayerWithGrads,self).__init__(**kwargs) + + def call(self,X): + raise NotImplementedError + + def call_inv(self,X): + raise NotImplementedError + + def compute_gradients(self,x,dy,regularizer=None,scaling=1): + ''' + Computes gradients for backward pass + Args: + x - tensor compatible with forward pass, input to the layer + dy - incoming gradient from backprop + regularizer - function, indicates dependence of loss on weights of layer + Returns + dy - gradients wrt input, to be backpropagated + grads - gradients wrt weights + ''' + #TODO check if log_det of AffineCouplingLayer depends needs a regularizer. + with tf.GradientTape() as tape: + tape.watch(x) + y_ = self.call(x) #Required to register the operation onto the gradient tape + grads_combined = tape.gradient(y_,[x]+self.trainable_variables,output_gradients=dy) + dy,grads = grads_combined[0],grads_combined[1:] + + if regularizer is not None: + with tf.GradientTape() as tape: + reg = -regularizer()/scaling + grads_wrt_reg = tape.gradient(reg, self.trainable_variables) + grads = [a[0]+a[1] for a in zip(grads,grads_wrt_reg) if a[1] is not None] + return dy,grads + +class Linear(LayerWithGrads): def __init__(self, **kwargs): super(Linear, self).__init__(**kwargs) @@ -199,26 +61,26 @@ def build(self, input_shape): assert len(input_shape) == 2 _, d = input_shape - self.W = self.add_weight(shape=(d, d), initializer='identity', name="linear_weight") - self.b = self.add_weight(shape=(d), initializer='zero', name="linear_bias") + self.W = self.add_weight(shape=(d, d), initializer='identity', name="linear_weight") + self.b = self.add_weight(shape=(d), initializer='zero', name="linear_bias") super(Linear, self).build(input_shape) self.built = True - def call(self, X): return X @ self.W + self.b + def call(self, X): return X @ self.W + self.b def call_inv(self, Z): return (Z - self.b) @ tf.linalg.inv(self.W) - def jacobian(self): return self.W + def jacobian(self): return self.W - def log_det(self): return tf.math.log(tf.abs(tf.linalg.det(self.jacobian()))) + def log_det(self): return tf.math.log(tf.abs(tf.linalg.det(self.jacobian()))) def compute_output_shape(self, input_shape): self.output_shape = input_shape return input_shape -class Affine(keras.layers.Layer): +class Affine(LayerWithGrads): """ The exp parameter allows the scaling to be exp(s) \odot X. @@ -234,27 +96,27 @@ def build(self, input_shape): #assert len(input_shape) == 2 d = input_shape[1:] - self.w = self.add_weight(shape=d, initializer='ones', name="affine_scale") - self.b = self.add_weight(shape=d, initializer='zero', name="affine_bias") + self.w = self.add_weight(shape=d, initializer='ones', name="affine_scale") + self.b = self.add_weight(shape=d, initializer='zero', name="affine_bias") super(Affine, self).build(input_shape) self.built = True - def call(self, X): - if self.exp: return X * tf.exp(self.w) + self.b - else: return X * self.w + self.b + def call(self, X): + if self.exp: return X * tf.exp(self.w) + self.b + else: return X * self.w + self.b def call_inv(self, Z): - if self.exp: return (Z - self.b) / tf.exp(self.w) - else: return (Z - self.b) / self.w + if self.exp: return (Z - self.b) / tf.exp(self.w) + else: return (Z - self.b) / self.w - def jacobian(self): return self.w + def jacobian(self): return self.w - def eigenvalues(self): return self.w + def eigenvalues(self): return self.w - def log_det(self): - if self.exp: return tf.reduce_sum(tf.abs(self.eigenvalues())) - else: return tf.reduce_sum(tf.math.log(tf.abs(self.eigenvalues()))) + def log_det(self): + if self.exp: return tf.reduce_sum(tf.abs(self.eigenvalues())) + else: return tf.reduce_sum(tf.math.log(tf.abs(self.eigenvalues()))) def compute_output_shape(self, input_shape): self.output_shape = input_shape @@ -286,12 +148,12 @@ def build(self, input_shape): # random orthogonal matrix # check if tf.linalg.qr and tf.linalg.lu are more stable than scipy. - self.kernel = self.add_weight(initializer=keras.initializers.Orthogonal(), shape=(c, c), name="inv_1x1_conv_P") + self.kernel = self.add_weight(initializer=keras.initializers.Orthogonal(), shape=(c, c), name="inv_1x1_conv_P") super(Inv1x1Conv, self).build(input_shape) self.built = True - def call(self, X): + def call(self, X): _W = tf.reshape(self.kernel, (1,1, self.c, self.c)) return tf.nn.conv2d(X, _W, [1,1,1,1], "SAME") @@ -301,7 +163,7 @@ def call_inv(self, Z): _W = tf.reshape(self.kernel_inv, (1,1, self.c, self.c)) return tf.nn.conv2d(Z, _W, [1,1,1,1], "SAME") - def log_det(self): # det computations are way too instable here.. + def log_det(self): # det computations are way too instable here.. return self.h * self.w * tf.math.log(tf.abs( tf.linalg.det(self.kernel) )) def compute_output_shape(self, input_shape): return input_shape @@ -309,7 +171,7 @@ def compute_output_shape(self, input_shape): return input_shape -class Inv1x1ConvPLU(keras.layers.Layer): +class Inv1x1ConvPLU(LayerWithGrads): """ Based on Glow page 11 appendix B. It is possible to speed up determinant computation by using PLU or QR decomposition @@ -334,7 +196,7 @@ def build(self, input_shape): # random orthogonal matrix # check if tf.linalg.qr and tf.linalg.lu are more stable than scipy. import scipy - w = scipy.linalg.qr(np.random.normal(0, 1, (self.c, self.c)))[0].astype(np.float32) + w = scipy.linalg.qr(np.random.normal(0, 1, (self.c, self.c)))[0].astype(np.float32) P, L, U = scipy.linalg.lu(w) def init_P(self, shape=None, dtype=None): return P @@ -350,7 +212,7 @@ def init_U(self, shape=None, dtype=None): return U L_mask = tf.constant(np.triu(np.ones((c,c)), k=+1), dtype=tf.float32) P_mask = tf.constant(np.tril(np.ones((c,c)), k=-1), dtype=tf.float32) - I = tf.constant(np.identity(c), dtype=tf.float32) + I = tf.constant(np.identity(c), dtype=tf.float32) self.P = self.P * P_mask + I self.L = self.L * L_mask + I @@ -361,15 +223,15 @@ def init_U(self, shape=None, dtype=None): return U self.L_inv = tf.linalg.inv(tf.dtypes.cast(L, dtype=tf.float64)) self.U_inv = tf.linalg.inv(tf.dtypes.cast(U, dtype=tf.float64)) - self.kernel_inv = tf.linalg.inv(self.kernel) # tf.dtypes.cast(self.U_inv @ self.L_inv @ self.P_inv, dtype=tf.float32) + self.kernel_inv = tf.linalg.inv(self.kernel) # tf.dtypes.cast(self.U_inv @ self.L_inv @ self.P_inv, dtype=tf.float32) - #self.I_ = self.kernel @ tf.linalg.inv(self.kernel) - #self.I = self.kernel @ self.kernel_inv + #self.I_ = self.kernel @ tf.linalg.inv(self.kernel) + #self.I = self.kernel @ self.kernel_inv super(Inv1x1Conv, self).build(input_shape) self.built = True - def call(self, X): + def call(self, X): _W = tf.reshape(self.kernel, (1,1, self.c, self.c)) return tf.nn.conv2d(X, _W, [1,1,1,1], "SAME") @@ -377,7 +239,7 @@ def call_inv(self, Z): _W = tf.reshape(self.kernel_inv, (1,1, self.c, self.c)) return tf.nn.conv2d(Z, _W, [1,1,1,1], "SAME") - def log_det(self): # det computations are way too instable here.. + def log_det(self): # det computations are way too instable here.. return self.h * self.w * tf.math.log(tf.abs( tf.linalg.det(self.kernel) )) # Looks fine? def compute_output_shape(self, input_shape): return input_shape @@ -401,7 +263,7 @@ class AdditiveCoupling(keras.Sequential): # refactor to be layer and to support def __init__(self, part=0, strategy=SplitOnHalfStrategy()): # strategy: alternate / split ;; alternate does odd/even, split has upper/lower. super(AdditiveCoupling, self).__init__(name="add_coupling_%i"%AdditiveCoupling.unique_id) AdditiveCoupling.unique_id += 1 - self.part = part + self.part = part self.strategy = strategy @@ -423,41 +285,67 @@ def call_(self, X): X = layer.call(X) return X - def call(self, X): - shape = tf.shape(X) - d = tf.reduce_prod(shape[1:]) - X = tf.reshape(X, (shape[0], d)) + def call(self, X): + shape = tf.shape(X) + d = tf.reduce_prod(shape[1:]) + X = tf.reshape(X, (shape[0], d)) x0, x1 = self.strategy.split(X) - if self.part == 0: x0 = x0 + self.call_(x1) - if self.part == 1: x1 = x1 + self.call_(x0) + if self.part == 0: x0 = x0 + self.call_(x1) + if self.part == 1: x1 = x1 + self.call_(x0) X = self.strategy.combine(x0, x1) - X = tf.reshape(X, shape) + X = tf.reshape(X, shape) return X - def call_inv(self, Z): - shape = tf.shape(Z) - d = tf.reduce_prod(shape[1:]) - Z = tf.reshape(Z, (shape[0], d)) + def call_inv(self, Z): + shape = tf.shape(Z) + d = tf.reduce_prod(shape[1:]) + Z = tf.reshape(Z, (shape[0], d)) z0, z1 = self.strategy.split(Z) - if self.part == 0: z0 = z0 - self.call_(z1) - if self.part == 1: z1 = z1 - self.call_(z0) + if self.part == 0: z0 = z0 - self.call_(z1) + if self.part == 1: z1 = z1 - self.call_(z0) Z = self.strategy.combine(z0, z1) - Z = tf.reshape(Z, shape) + Z = tf.reshape(Z, shape) return Z - def log_det(self): return 0. + def log_det(self): return 0. def compute_output_shape(self, input_shape): return input_shape + def compute_gradients(self,x,dy,regularizer=None,scaling=1): + ''' + Computes gradients for backward pass + Since the coupling layers do not inherit from `LayerWithGrads`, this + function is re-written. See TODO of AffineCoupling for further info + Args: + x - tensor compatible with forward pass, input to the layer + dy - incoming gradient from backprop + regularizer - function, indicates dependence of loss on weights of layer + Returns + dy - gradients wrt input, to be backpropagated + grads - gradients wrt weights + ''' + with tf.GradientTape() as tape: + tape.watch(x) + y_ = self.call(x) #Required to register the operation onto the gradient tape + grads_combined = tape.gradient(y_,[x]+self.trainable_variables,output_gradients=dy) + dy,grads = grads_combined[0],grads_combined[1:] + + if regularizer is not None: + with tf.GradientTape() as tape: + reg = -regularizer() + grads_wrt_reg = tape.gradient(reg, self.trainable_variables) + grads = [a[0]+a[1] for a in zip(grads,grads_wrt_reg)] + return dy,grads + @@ -512,7 +400,7 @@ def compute_output_shape(self, input_shape): return input_shape - Downscale images, e.g. alternate pixels and have 4 lower dim images and stack them. - ... """ -class Squeeze(keras.layers.Layer): +class Squeeze(LayerWithGrads): def call(self, X): n, self.w, self.h, self.c = X.shape @@ -521,7 +409,7 @@ def call(self, X): def call_inv(self, X): return tf.reshape(X, [-1, self.w, self.h, self.c]) - def log_det(self): return 0. + def log_det(self): return tf.zeros((1,)) class UnSqueeze(keras.layers.Layer): @@ -537,7 +425,7 @@ def log_det(self): return 0. -class Normalize(keras.layers.Layer): # normalizes data after dequantization. +class Normalize(LayerWithGrads): # normalizes data after dequantization. """ """ @@ -556,7 +444,7 @@ def build(self, input_shape): self.built = True def call(self, X): - X = X * self.scale - 1 + X = X * self.scale - 1 return X def call_inv(self, Z): @@ -583,15 +471,15 @@ def compute_output_shape(self, input_shape): n, h, w, c = input_shape return (n, h, w, c//2) - def log_det(self): return 0. + def log_det(self): return tf.zeros((1,)) -""" - There's an issue with scaling, which intuitively makes step-size VERY small. -""" -class Conv3DCirc(keras.layers.Layer): +class Conv3DCirc(LayerWithGrads): + """ + There's an issue with scaling, which intuitively makes step-size VERY small. + """ def __init__(self,trainable=True): self.built = False @@ -669,28 +557,232 @@ def __call__(self, w): return w / tf.math.reduce_max(w) # TODO: This needs to be self.built = True -class Reshape(keras.layers.Layer): - def __init__(self, shape): - self.shape = shape - super(Reshape, self).__init__() + +class InvResNet(keras.layers.Layer): pass # model should automatically use gradient checkpointing if this is used. + + +# the 3D case, refactor to make it into the general case. +# make experiment with nD case, maybe put reshape into it? +# Theoretically time is the same? +class CircularConv(keras.layers.Layer): + + def __init__(self, dim=3): # + self.dim = dim + + def call(self, X): pass + + def call_inv(self, X): pass + + def log_det(self): pass + + + + + +""" + The affine coupling layer is described in NICE, REALNVP and GLOW. + The description in Glow use a single network to output scale s and transform t, + it seems the description in REALNVP is a bit more general refering to s and t as + different functions. From this perspective Glow change the affine layer to have + weight sharing between s and t. + Specifying a single function is a lot simpler code-wise, we thus use that approach. + + + For now assumes the use of convolutions + +""" +class AffineCoupling(LayerWithGrads): # Sequential): + + def add(self, layer): self.layers.append(layer) + + unique_id = 1 + + def __init__(self, part=0, strategy=SplitChannelsStrategy()): + super(AffineCoupling, self).__init__(name="aff_coupling_%i"%AffineCoupling.unique_id) + AffineCoupling.unique_id += 1 + self.part = part + self.strategy = strategy + self.layers = [] + self._is_graph_network = False + + + def _check_trainable_weights_consistency(self): return True + + def build(self, input_shape): + + # handle the issue with each network output something larger. + _, h, w, c = input_shape + + + h, w, c = self.strategy.coupling_shape(input_shape=(h,w,c)) + + self.layers[0].build(input_shape=(None, h, w, c)) + out_dim = self.layers[0].compute_output_shape(input_shape=(None, h, w, c)) + self.layers[0].output_shape_ = out_dim + + for layer in self.layers[1:]: + layer.build(input_shape=out_dim) + out_dim = layer.compute_output_shape(input_shape=out_dim) + layer.output_shape_ = out_dim + + + super(AffineCoupling, self).build(input_shape) + self.built = True + + def call_(self, X): + + in_shape = tf.shape(X) + n, h, w, c = X.shape + + for layer in self.layers: + X = layer.call(X) + + # TODO: Could have a part of network learned specifically for s,t to not ONLY have wegith sharing? + + # Using strategy from + # https://github.com/openai/glow/blob/eaff2177693a5d84a1cf8ae19e8e0441715b82f8/model.py#L376 + X = tf.reshape(X, (-1, h, w, c*2)) + s = X[:, :, :, ::2] # add a strategy pattern to decide how the output is split. + t = X[:, :, :, 1::2] + #s = tf.math.sigmoid(s) + + #s = X[:, :, w//2:, :] + #t = X[:, :, :w//2, :] + + s = tf.reshape(s, in_shape) + t = tf.reshape(t, in_shape) + + return s, t + + def call(self, X): + + x0, x1 = self.strategy.split(X) + + if self.part == 0: + s, t = self.call_(x1) + x0 = x0*s + t # glow changed order of this? i.e. translate then scale. + + if self.part == 1: + s, t = self.call_(x0) + x1 = x1*s + t + + self.precompute_log_det(s, X) + # print("s",np.isnan(s),np.isnan(t)) + X = self.strategy.combine(x0, x1) + #Diagnostic statements for testing NaN gradient + # print("s",np.isnan(s).all(),"t",np.isnan(t).all()) + # print("X0",np.isnan(x0).all(),"X1",np.isnan(x1).all()) + return X + + def call_inv(self, Z): + z0, z1 = self.strategy.split(Z) + + if self.part == 0: + s, t = self.call_(z1) + z0 = (z0 - t)/s + if self.part == 1: + s, t = self.call_(z0) + z1 = (z1 - t)/s + + Z = self.strategy.combine(z0, z1) + return Z + + def precompute_log_det(self, s, X): + n = tf.dtypes.cast(tf.shape(X)[0], tf.float32) + self._log_det = tf.reduce_sum(tf.math.log(tf.abs(s))) / n + + def log_det(self): return self._log_det + + def compute_output_shape(self, input_shape): return input_shape + + def summary(self, line_length=None, positions=None, print_fn=None): + print_summary(self, line_length=line_length, positions=positions, print_fn=print_fn) # fixes stupid issue. + + def compute_gradients(self,x,dy,regularizer=None,scaling=1): + ''' + Computes gradients for backward pass + Args: + x - tensor compatible with forward pass, input to the layer + dy - incoming gradient from backprop + regularizer - function, indicates dependence of loss on weights of layer + Returns + dy - gradients wrt input, to be backpropagated + grads - gradients wrt weights + ''' + #TODO check if log_det of AffineCouplingLayer depends needs a regularizer. -- It does + #TODO fix bug of incorrect dy + with tf.GradientTape(persistent=True) as tape: #Since log_det is computed within call + tape.watch(x) + y_ = self.call(x) #Required to register the operation onto the gradient tape + reg = -self._log_det/scaling + grads_combined = tape.gradient(y_,[x]+self.trainable_variables,output_gradients=dy) + grads_wrt_reg = tape.gradient(reg,self.trainable_variables) + grads_of_inp = tape.gradient(reg,[x]) + dy,grads = grads_combined[0],grads_combined[1:] + grads = [a[0]+a[1] for a in zip(grads,grads_wrt_reg)] + n = tf.dtypes.cast(tf.shape(x)[0], tf.float32) + dy = [a[1]+a[0] for a in zip(dy,grads_of_inp)] #TODO check this expression, seems numerically approximate + del tape #Since tape was persistent, we need this + return dy,grads + +class ReduceNumBits(LayerWithGrads): + """ + Glow used 5 bit variant of CelebA. + Flow++ had 3 and 5 bit variants of ImageNet. + These lower bit variants allow better dimensionality reduction. + This layer should be the first within the model. + + This also means subsequent normalization needs to divide by less. + In this sense likelihood is incomparable between different number of bits. + + It seems to work, but it is a bit instable. + """ + def __init__(self, bits=5): # assumes input has 8 bits. + self.bits = 5 + super(ReduceNumBits, self).__init__() def call(self, X): - self.prev_shape = X.shape - return tf.reshape(X, (-1, ) + self.shape) + X = tf.dtypes.cast(X, dtype=np.float32) + return X // ( 2**(8-self.bits) ) - def log_det(self): return .0 + def call_inv(self, Z): + # THIS PART IS NOT INVERTIBLE!! + return Z * (2**(8-self.bits)) - def call_inv(self, X): return tf.reshape(X, self.input_shape) + def log_det(self): return 0. + +class ActNorm(LayerWithGrads): + """ + The exp parameter allows the scaling to be exp(s) \odot X. + This cancels out the log in the log_det computations. + """ + + def __init__(self, exp=False, **kwargs): + self.exp = exp + super(ActNorm, self).__init__(**kwargs) + + def build(self, input_shape): + + n, h, w, c = input_shape + self.h = h + self.w = w + self.s = self.add_weight(shape=c, initializer='ones', name="affine_scale") + self.b = self.add_weight(shape=c, initializer='zero', name="affine_bias") + + super(ActNorm, self).build(input_shape) + self.built = True -class InvResNet(keras.layers.Layer): + def call(self, X): return X * self.s + self.b + def call_inv(self, Z): return (Z - self.b) / self.s + def log_det(self): return self.h * self.w * tf.reduce_sum(tf.math.log(tf.abs(self.s))) + def compute_output_shape(self, input_shape): return input_shape - pass # model should automatically use gradient checkpointing if this is used. diff --git a/test.py b/test.py index e72b3f5..10329b0 100644 --- a/test.py +++ b/test.py @@ -51,10 +51,10 @@ """ import unittest #from test.shape import * -from test.jacobian import * +# from test.jacobian import * #from test.optimality import * -from test.inverse import * - +# from test.inverse import * +from test.gradients import * if __name__ == "__main__": unittest.main() diff --git a/test/gradients.py b/test/gradients.py new file mode 100644 index 0000000..d1aec04 --- /dev/null +++ b/test/gradients.py @@ -0,0 +1,92 @@ +import unittest +import sys +sys.path.append("../") +import invtf.latent +# import invtf.layers +#from tensorflow.python.ops.parallel_for.gradients import jacobian +import tensorflow as tf +import tensorflow.keras as keras +import numpy as np +from invtf.generator_const_backprop import Generator as GenConst +from invtf.layers_const_backprop import * +from tensorflow.keras.layers import ReLU, Dense, Flatten, Conv2D + +class GeneratorGradTest(GenConst): + def prune(self,l): + return [x for sublist in l for x in sublist if len(sublist)>0] + def compute_gradients(self,X): + x = self.call(X) #I think putting this in context records all operations onto the tape, thereby destroying purpose of checkpointing... + last_layer = self.layers[-1] + d = np.prod(X.shape[1:]) + #Computing gradients of loss function wrt the last acticvation + with tf.GradientTape() as tape: + tape.watch(x) + loss = self.loss(x) #May have to change + grads_combined = tape.gradient(loss,[x]) + dy = grads_combined[0] + y = x + #Computing gradients for each layer + gradients = [] + for layer in self.layers[::-1]: + x = layer.call_inv(y) + dy,grads = layer.compute_gradients(x,dy,layer.log_det,d*np.log(2.)) #TODO implement scaling here -- DONE + gradients=[grads]+gradients + y = x + return self.prune(gradients) + + def actual_gradients(self,X): + with tf.GradientTape() as tape: + loss = self.loss(self.call(X)) + grads = tape.gradient(loss,self.trainable_variables) + return grads + +class TestGradients(unittest.TestCase): + X = keras.datasets.cifar10.load_data()[0][0][:5].astype('f') # a single cifar image batch. + + def assertGrad(self,g,X): + computed_grads = g.compute_gradients(X) + actual_grads = g.actual_gradients(X) + A = [np.allclose(np.abs(x[0]-x[1]),0,atol=1, rtol=0.1) for x in zip(computed_grads,actual_grads) if x[0] is not None] + # print("computed",computed_grads,"actual_grads",actual_grads) + print("Max discrepancy in gradients",np.max(np.array([np.max((np.abs(x[0]-x[1]))) for x in zip(computed_grads,actual_grads) if x[0] is not None]))) + self.assertTrue(np.array(A).all()) + + def test_circ_conv(self): + X = TestGradients.X + d = 32*32*3 + g = GeneratorGradTest(invtf.latent.Normal(d)) + g.add(Conv3DCirc()) + g.predict(X[:1]) + self.assertGrad(g,X) + + def test_inv_conv(self): + X = TestGradients.X + d = 32*32*3 + g = GeneratorGradTest(invtf.latent.Normal(d)) + g.add(Inv1x1ConvPLU()) + g.predict(X[:1]) + self.assertGrad(g,X) + + def test_act_norm(self): + X = TestGradients.X + d = 32*32*3 + g = GeneratorGradTest(invtf.latent.Normal(d)) + g.add(ActNorm()) + g.predict(X[:1]) + self.assertGrad(g,X) + + def test_affine_coupling(self): + X = np.random.normal(0,1,(5,2,2,2)).astype('f') + print(X.shape) + d = 2*2*2 + g = GeneratorGradTest(invtf.latent.Normal(d)) + b = AffineCoupling() + b.add(Flatten()) + b.add(Dense(d,activation='sigmoid')) + g.add(Squeeze()) + g.add(Conv3DCirc()) + g.add(b) + g.add(Conv3DCirc()) + # g.predict(X[:1]) + self.assertGrad(g,X) +