diff --git a/Code/lucid_ml/classifying/neural_net.py b/Code/lucid_ml/classifying/neural_net.py index 616826b..cec3c0c 100644 --- a/Code/lucid_ml/classifying/neural_net.py +++ b/Code/lucid_ml/classifying/neural_net.py @@ -3,13 +3,32 @@ from scipy import sparse from sklearn.base import BaseEstimator -from keras.layers import Dense, Activation, Dropout, BatchNormalization +from keras.layers import Dense, Activation, Dropout from keras.models import Sequential from keras.optimizers import Adam import numpy as np from sklearn.metrics import f1_score -from sklearn.linear_model import Ridge - +from sklearn.linear_model import Ridge +from keras.callbacks import EarlyStopping, ModelCheckpoint + +#=============================================================================== +# class EarlyStoppingBySklearnMetric(Callback): +# def __init__(self, metric=lambda y_test, y_pred : f1_score(y_test, y_pred, average='samples'), value=0.00001, verbose=0): +# super(Callback, self).__init__() +# self.metric = metric +# self.value = value +# self.verbose = verbose +# +# def on_epoch_end(self, epoch, logs={}): +# current = logs.get(self.monitor) +# if current is None: +# warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning) +# +# if current < self.value: +# if self.verbose > 0: +# print("Epoch %05d: early stopping THR" % epoch) +# self.model.stop_training = True +#=============================================================================== def _batch_generator(X, y, batch_size, shuffle): number_of_batches = np.ceil(X.shape[0] / batch_size) @@ -43,10 +62,22 @@ def _batch_generatorp(X, batch_size): class MLP(BaseEstimator): - def __init__(self, verbose=0, model=None, final_activation='sigmoid'): + def __init__(self, verbose=0, model=None, final_activation='sigmoid', batch_size = 512, learning_rate = None, epochs = 20): self.verbose = verbose self.model = model self.final_activation = final_activation + self.batch_size = batch_size + self.validation_data_position = None + self.epochs = epochs + + # we scale the learning rate proportionally with the batch size as suggested by + # [Thomas M. 
Breuel, 2015, The Effects of Hyperparameters on SGD + # Training of Neural Networks] + # we found lr=0.01 to be a good learning rate for batch size 512 + if learning_rate is None: + self.lr = self.batch_size / 512 * 0.01 + else: + self.lr = learning_rate def fit(self, X, y): if not self.model: @@ -56,16 +87,32 @@ def fit(self, X, y): self.model.add(Dropout(0.5)) self.model.add(Dense(y.shape[1])) self.model.add(Activation(self.final_activation)) - self.model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=0.01)) - self.model.fit_generator(generator=_batch_generator(X, y, 256, True), - samples_per_epoch=X.shape[0], nb_epoch=20, verbose=self.verbose) + self.model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=self.lr)) + + val_pos = self.validation_data_position + + + callbacks = [] + if self.validation_data_position is not None: + callbacks.append(EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=0, mode='auto')) + callbacks.append(ModelCheckpoint("weights.best.hdf5", monitor='val_loss', verbose=1, save_best_only=True, mode='min')) + X_train, y_train, X_val, y_val = X[:val_pos, :], y[:val_pos,:], X[val_pos:, :], y[val_pos:,:] + else: + X_train, y_train = X, y + self.model.fit_generator(generator=_batch_generator(X_train, y_train, self.batch_size, True), callbacks=callbacks, + steps_per_epoch=int(X.shape[0] / float(self.batch_size)) + 1, nb_epoch=self.epochs, verbose=self.verbose, + validation_data = _batch_generator(X_val, y_val, self.batch_size, False) if self.validation_data_position is not None else None, + validation_steps = 10) + + if self.validation_data_position is not None: + self.model.load_weights("weights.best.hdf5") def predict(self, X): pred = self.predict_proba(X) return sparse.csr_matrix(pred > 0.2) def predict_proba(self, X): - pred = self.model.predict_generator(generator=_batch_generatorp(X, 512), val_samples=X.shape[0]) + pred = self.model.predict_generator(generator=_batch_generatorp(X, self.batch_size), steps=int(X.shape[0] / float(self.batch_size)) + 1) return pred diff --git a/Code/lucid_ml/classifying/tensorflow_models.py b/Code/lucid_ml/classifying/tensorflow_models.py new file mode 100644 index 0000000..0f08f01 --- /dev/null +++ b/Code/lucid_ml/classifying/tensorflow_models.py @@ -0,0 +1,931 @@ +import numpy as np +import tensorflow as tf +from scipy.sparse.csr import csr_matrix +import scipy.sparse as sps +from sklearn.base import BaseEstimator +from sklearn.metrics import f1_score +from sklearn.preprocessing import LabelBinarizer +import math, numbers, os +from tensorflow.python.framework import ops, tensor_shape, tensor_util +from tensorflow.python.ops import math_ops, random_ops, array_ops +from tensorflow.python.layers import utils +from datetime import datetime +from utils.tf_utils import tf_normalize, sequence_length, average_outputs, dynamic_max_pooling +#tf.logging.set_verbosity(tf.logging.INFO) + +def _load_embeddings(filename, vocab_size, embedding_size): + + embeddings = np.random.normal(scale = 0.1, size = (vocab_size, embedding_size)) + with open(filename + ".tmp",'r') as embedding_file: + + i = 0 + for line in embedding_file.readlines(): + row = line.strip().split(' ') + # omit escape sequences + if len(row) != embedding_size + 1: + continue + else: + embeddings[i, :] = np.asarray(row[1:], dtype=np.float32) + i += 1 + return embeddings, embedding_size + +def _embeddings(x_tensor, vocab_size, embedding_size, pretrained_embeddings = True, trainable_embeddings = True): + + if pretrained_embeddings: + 
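+        # The placeholder + tf.assign pattern below keeps the (potentially large) pretrained
+        # embedding matrix out of the serialized graph definition: the caller feeds it once
+        # via the returned (embedding_init, {embedding_placeholder: embeddings}) initializer
+        # operation before training starts.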
embedding_placeholder = tf.placeholder(tf.float32, shape=[vocab_size, embedding_size]) + lookup_table = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), + trainable=trainable_embeddings, name="W") + embedding_init = tf.assign(lookup_table, embedding_placeholder) + + embedded_words = tf.nn.embedding_lookup(lookup_table, x_tensor) + return embedded_words, embedding_init, embedding_placeholder + + else: + lookup_table = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0), + name="W") + embedded_words = tf.nn.embedding_lookup(lookup_table, x_tensor) + return embedded_words + + +def _extract_vocab_size(X): + # max index of vocabulary is encoded in last column + vocab_size = X[0, -1] + 1 + + # slice of the column from the input, as it's not part of the sequence + sequence_length = X.shape[1] + x_tensor = tf.placeholder(tf.int32, shape=(None, sequence_length), name = "x") + feature_input = tf.slice(x_tensor, [0, 0], [-1, sequence_length - 1]) + return x_tensor, vocab_size, feature_input + +def _init_embedding_layer(pretrained_embeddings_path, feature_input, embedding_size, vocab_size, + params_fit, params_predict, trainable_embeddings, initializer_ops): + if pretrained_embeddings_path is not None: + embeddings, embedding_size = _load_embeddings(pretrained_embeddings_path, vocab_size, embedding_size) + + embedded_words, embedding_init, embedding_placeholder = _embeddings(feature_input, vocab_size, embedding_size, + pretrained_embeddings = True, + trainable_embeddings = trainable_embeddings) + initializer_ops.append((embedding_init, {embedding_placeholder : embeddings})) + + else: + embedded_words = _embeddings(feature_input, vocab_size, embedding_size, + pretrained_embeddings = False, trainable_embeddings = trainable_embeddings) + + return embedded_words, embedding_size + +def extract_axis_1(data, ind): + """ + Get specified elements along the first axis of tensor. + :param data: Tensorflow tensor that will be subsetted. + :param ind: Indices to take (one for each element along axis 0 of data). + :return: Subsetted tensor. 
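+    Example (illustrative): for `data` of shape [batch_size, max_time, dim] and
+    `ind = seq_length - 1`, this returns the output at the last valid time step of
+    each sequence, with shape [batch_size, dim].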
+ """ + + batch_range = tf.range(tf.shape(data)[0]) + indices = tf.stack([batch_range, ind], axis=1) + res = tf.gather_nd(data, indices) + + return res + +def _word_attention(output_state): + # we perform attention according to Hierarchical Attention Network (word attention) + ## compute hidden representation of outputs + ### context vector size as in HN-ATT + context_vector_size = 100 + hidden_output_representation = tf.contrib.layers.fully_connected(output_state, context_vector_size, activation_fn = tf.nn.tanh) + ## compute dot product with context vector + context_vector = tf.Variable(tf.random_normal([context_vector_size], stddev=0.1)) + dot_product = tf.tensordot(hidden_output_representation, context_vector, [[2], [0]]) + ## compute weighted sum + attention_weights = tf.nn.softmax(dot_product) + attention_weights = tf.expand_dims(attention_weights, -1) + output_state = tf.reduce_sum(tf.multiply(output_state, attention_weights), axis = 1) + + return output_state + +def lstm_fn(X, y, keep_prob_dropout = 0.5, embedding_size = 30, hidden_layers = [1000], + aggregate_output = "average", + pretrained_embeddings_path = None, + trainable_embeddings = True, + variational_recurrent_dropout = True, + bidirectional = False, + iterate_until_maxlength = False, + num_last_outputs = 1): + """Model function for LSTM.""" + + x_tensor, vocab_size, feature_input = _extract_vocab_size(X) + + y_tensor = tf.placeholder(tf.float32, shape=(None, y.shape[1]), name = "y") + dropout_tensor = tf.placeholder(tf.float32, name = "dropout") + + params_fit = {dropout_tensor : keep_prob_dropout} + params_predict = {dropout_tensor : 1} + + if iterate_until_maxlength: + # create a vector of correct shape and set to maxlen + seq_length = tf.reduce_sum(feature_input, 1) + seq_length = tf.cast(seq_length, tf.int32) + seq_length = feature_input.get_shape().as_list()[1] * tf.ones_like(seq_length) + else: + seq_length = sequence_length(feature_input) + + initializer_operations = [] + + embedded_words, _ = _init_embedding_layer(pretrained_embeddings_path, feature_input, + embedding_size, vocab_size, + params_fit, + params_predict, + trainable_embeddings, + initializer_operations) + + def create_multilayer_lstm(): + # build multiple layers of lstms + lstm_layers = [] + for hidden_layer_size in hidden_layers: + single_lstm_layer = tf.contrib.rnn.LSTMCell(hidden_layer_size, use_peepholes = True) + if variational_recurrent_dropout: + single_lstm_layer = tf.contrib.rnn.DropoutWrapper(single_lstm_layer, + input_keep_prob=1., + output_keep_prob=1., + state_keep_prob=dropout_tensor, + variational_recurrent=True, + dtype = tf.float32) + lstm_layers.append(single_lstm_layer) + stacked_lstm = tf.contrib.rnn.MultiRNNCell(lstm_layers) + return stacked_lstm + + forward_lstm = create_multilayer_lstm() + forward_state = forward_lstm.zero_state(tf.shape(embedded_words)[0], tf.float32) + + # bidirectional lstm? 
+ ## we can discard the state after the batch is fully processed + if not bidirectional: + output_state, _ = tf.nn.dynamic_rnn(forward_lstm, embedded_words, initial_state = forward_state, sequence_length = seq_length) + else: + backward_lstm = create_multilayer_lstm() + backward_state = backward_lstm.zero_state(tf.shape(embedded_words)[0], tf.float32) + bidi_output_states, _ = tf.nn.bidirectional_dynamic_rnn(forward_lstm, backward_lstm, embedded_words, + initial_state_fw = forward_state, initial_state_bw = backward_state, + sequence_length = seq_length) + ## we concatenate the outputs of forward and backward rnn in accordance with Hierarchical Attention Networks + h1, h2 = bidi_output_states + output_state = tf.concat([h1, h2], axis = 2, name = "concat_bidi_output_states") + + # note that dynamic_rnn returns zero outputs after seq_length + if aggregate_output == "sum": + output_state = tf.reduce_sum(output_state, axis = 1) + elif aggregate_output == "average": + output_state = average_outputs(output_state, seq_length) + elif aggregate_output == "last": + # return output at last time step + output_state = extract_axis_1(output_state, seq_length - 1) + elif aggregate_output == "attention": + output_state = _word_attention(output_state) + elif aggregate_output == "oe-attention": + # perform attention over overeager outputs + output_state = tf.concat([extract_axis_1(output_state, seq_length - (num_last_outputs + 1 - i)) for i in range(num_last_outputs)], axis = 1) + _word_attention(output_state) + else: + raise ValueError("Aggregation method not implemented!") + + hidden_layer = tf.nn.dropout(output_state, dropout_tensor) + + return x_tensor, y_tensor, hidden_layer, params_fit, params_predict, initializer_operations +def cnn_fn(X, y, keep_prob_dropout = 0.5, embedding_size = 30, hidden_layers = [1000], + pretrained_embeddings_path = None, + trainable_embeddings = True, + dynamic_max_pooling_p = 1, + # these are set according to Kim's Sentence Classification + window_sizes = [3, 4, 5], + num_filters = 100): + """Model function for CNN.""" + + # x_tensor includes the max_index_column, feature_input doesnt. 
go on with feature_input, but return x_tensor for feed_dict + x_tensor, vocab_size, feature_input = _extract_vocab_size(X) + max_length = X.shape[1] - 1 + + y_tensor = tf.placeholder(tf.float32, shape=(None, y.shape[1]), name = "y") + dropout_tensor = tf.placeholder(tf.float32, name = "dropout") + + params_fit = {dropout_tensor : keep_prob_dropout} + params_predict = {dropout_tensor : 1} + + seq_length = sequence_length(feature_input) + seq_length = tf.cast(seq_length, tf.float32) + + initializer_operations = [] + embedded_words, embedding_size = _init_embedding_layer(pretrained_embeddings_path, + feature_input, embedding_size, + vocab_size, params_fit, + params_predict, + trainable_embeddings, + initializer_operations) + + # need to extend the number of dimensions here in order to use the predefined pooling operations, which assume 2d pooling + embedded_words = tf.expand_dims(embedded_words, -1) + + stride = [1, 1, 1, 1] + padding = "VALID" + + pooled_outputs = [] + for window_size in window_sizes: + filter_weights = tf.Variable(tf.random_normal([window_size, embedding_size, 1, num_filters], stddev=0.1)) + conv = tf.nn.conv2d(embedded_words, filter_weights, stride, padding) + bias = tf.Variable(tf.random_normal([num_filters])) + detector = tf.nn.relu(tf.nn.bias_add(conv, bias)) + + concatenated_pooled_chunks = dynamic_max_pooling(detector, seq_length, max_length, num_filters, window_size, dynamic_max_pooling_p = dynamic_max_pooling_p) + pooled_outputs.append(concatenated_pooled_chunks) + + concatenated_pools = tf.concat(pooled_outputs, 1) + num_filters_total = num_filters * len(window_sizes) * dynamic_max_pooling_p + hidden_layer = tf.reshape(concatenated_pools, [-1, num_filters_total]) + + hidden_layer = tf.nn.dropout(hidden_layer, dropout_tensor) + + return x_tensor, y_tensor, hidden_layer, params_fit, params_predict, initializer_operations + +def mlp_base_fn(X, y, keep_prob_dropout = 0.5, hidden_activation_function = tf.nn.relu): + """Model function for MLP-Soph.""" + # convert sparse tensors to dense + x_tensor = tf.placeholder(tf.float32, shape=(None, X.shape[1]), name = "x") + y_tensor = tf.placeholder(tf.float32, shape=(None, y.shape[1]), name = "y") + dropout_tensor = tf.placeholder(tf.float32, name = "dropout") + + params_fit = {dropout_tensor : keep_prob_dropout} + params_predict = {dropout_tensor : 1} + + # Connect the first hidden layer to input layer + # (features) with relu activation and add dropout + hidden_layer = tf.contrib.layers.fully_connected(x_tensor, 1000, activation_fn = hidden_activation_function) + hidden_dropout = tf.nn.dropout(hidden_layer, dropout_tensor) + + return x_tensor, y_tensor, hidden_dropout, params_fit, params_predict, [] + +# https://github.com/bioinf-jku/SNNs/blob/master/selu.py +def dropout_selu(x, rate, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0, + noise_shape=None, seed=1337, name=None, training=False): + """Dropout to a value with rescaling.""" + + def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name): + keep_prob = 1.0 - rate + x = ops.convert_to_tensor(x, name="x") + if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: + raise ValueError("keep_prob must be a scalar tensor or a float in the " + "range (0, 1], got %g" % keep_prob) + keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob") + keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + + alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha") + 
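+        # Unlike standard dropout, dropped units are set to alpha (the SELU negative
+        # saturation value) instead of 0; the affine correction a * ret + b computed
+        # below then restores the fixed-point mean and variance of the self-normalizing net.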
alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + + if tensor_util.constant_value(keep_prob) == 1: + return x + + noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) + random_tensor = keep_prob + random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype) + binary_tensor = math_ops.floor(random_tensor) + ret = x * binary_tensor + alpha * (1-binary_tensor) + + a = math_ops.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * math_ops.pow(alpha-fixedPointMean,2) + fixedPointVar))) + + b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha) + ret = a * ret + b + ret.set_shape(x.get_shape()) + return ret + + with ops.name_scope(name, "dropout", [x]) as name: + return utils.smart_cond(training, + lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name), + lambda: array_ops.identity(x)) + +# https://github.com/bioinf-jku/SNNs/blob/master/selu.py +def selu(x): + with ops.name_scope('elu') as scope: + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x)) + +def mlp_soph_fn(X, y, keep_prob_dropout = 0.5, embedding_size = 30, hidden_layers = [1000], self_normalizing = False, hidden_activation_function = tf.nn.relu, + standard_normal = False, batch_norm = True): + """Model function for MLP-Soph.""" + + # convert sparse tensors to dense + x_tensor = tf.placeholder(tf.float32, shape=(None, X.shape[1]), name = "x") + y_tensor = tf.placeholder(tf.float32, shape=(None, y.shape[1]), name = "y") + dropout_tensor = tf.placeholder(tf.float32, name = "dropout") + + params_fit = {dropout_tensor : keep_prob_dropout} + params_predict = {dropout_tensor : 1} + + # we need to have the input data scaled such they have mean 0 and variance 1 + if standard_normal: + scaled_input = tf_normalize(X, x_tensor) + else: + scaled_input = x_tensor + + # apply a look-up as described by the fastText paper + if embedding_size > 0: + lookup_table = tf.Variable(tf.truncated_normal([X.shape[1], embedding_size], mean=0.0, stddev=0.1)) + embedding_layer = tf.matmul(scaled_input, lookup_table) + else: + embedding_layer = scaled_input + + # Connect the embedding layer to the hidden layers + # (features) with relu activation and add dropout everywhere + hidden_layer = embedding_layer + for hidden_units in hidden_layers: + if not self_normalizing: + normalizer_fn = tf.layers.batch_normalization if batch_norm else None + hidden_layer = tf.contrib.layers.fully_connected(hidden_layer, hidden_units, activation_fn = hidden_activation_function, normalizer_fn = normalizer_fn) + hidden_layer = tf.nn.dropout(hidden_layer, dropout_tensor) + else: + hidden_layer = tf.contrib.layers.fully_connected(hidden_layer, hidden_units, + activation_fn=None, + weights_initializer=tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')) + hidden_layer = selu(hidden_layer) + # dropout_selu expects to be given the dropout rate instead of keep probability + hidden_layer = dropout_selu(hidden_layer, tf.constant(1, tf.float32) - dropout_tensor) + + return x_tensor, y_tensor, hidden_layer, params_fit, params_predict, [] + +def swish(x): + """ + Implementation of Swish: a Self-Gated Activation Function (https://arxiv.org/abs/1710.05941) + """ + return x * tf.sigmoid(x) + +def _transform_activation_function(func): + if func == "relu": + hidden_activation_function = tf.nn.relu + elif func == "tanh": + hidden_activation_function = tf.nn.tanh + elif func == "identity": + 
hidden_activation_function = tf.identity + elif func == "swish": + hidden_activation_function = swish + return hidden_activation_function + +def mlp_base(hidden_activation_function = "relu"): + """ + Returns a function that can be used as the get_model parameter for MultiLabelSKFlow so it trains Base-MLP. + """ + hidden_activation_function = _transform_activation_function(hidden_activation_function) + + return lambda X, y : mlp_soph_fn(X, y, keep_prob_dropout = 0.5, embedding_size = 0, + hidden_layers = [1000], self_normalizing = False, + hidden_activation_function=hidden_activation_function, + standard_normal = False, batch_norm = False) + +def mlp_soph(keep_prob_dropout, embedding_size, hidden_layers, self_normalizing, standard_normal, batch_norm, hidden_activation_function = "relu"): + """ + Returns a function that can be used as the get_model parameter for MultiLabelSKFlow so it trains MLP with the specified parameters. + """ + hidden_activation_function = _transform_activation_function(hidden_activation_function) + + return lambda X, y : mlp_soph_fn(X, y, keep_prob_dropout = keep_prob_dropout, embedding_size = embedding_size, + hidden_layers = hidden_layers, self_normalizing = self_normalizing, + hidden_activation_function=hidden_activation_function, + standard_normal = standard_normal, batch_norm = batch_norm) + +def cnn(keep_prob_dropout, embedding_size, hidden_layers, pretrained_embeddings_path, trainable_embeddings, dynamic_max_pooling_p, window_sizes, num_filters): + """ + Returns a function that can be used as the get_model parameter for MultiLabelSKFlow so it trains CNN with the specified configuration. + """ + return lambda X, y : cnn_fn(X, y, keep_prob_dropout = keep_prob_dropout, embedding_size = embedding_size, + hidden_layers = hidden_layers, pretrained_embeddings_path=pretrained_embeddings_path, + trainable_embeddings = trainable_embeddings, + dynamic_max_pooling_p = dynamic_max_pooling_p, + window_sizes = window_sizes, + num_filters = num_filters) + +def lstm(keep_prob_dropout, embedding_size, hidden_layers, pretrained_embeddings_path, trainable_embeddings, variational_recurrent_dropout, + bidirectional, aggregate_output, iterate_until_maxlength, num_last_outputs): + """ + Returns a function that can be used as the get_model parameter for MultiLabelSKFlow so it trains LSTM with the specified configuration. + """ + + return lambda X, y : lstm_fn(X, y, keep_prob_dropout = keep_prob_dropout, embedding_size = embedding_size, + hidden_layers = hidden_layers, pretrained_embeddings_path=pretrained_embeddings_path, + trainable_embeddings = trainable_embeddings, + variational_recurrent_dropout=variational_recurrent_dropout, + bidirectional = bidirectional, + aggregate_output = aggregate_output, + iterate_until_maxlength = iterate_until_maxlength, + num_last_outputs=num_last_outputs) + + +class BatchGenerator: + """ + Splits a training/test set into batches to enable mini-batch training. + + Parameters + ---------- + X: np.array, or csr_matrix + The entire dataset to split into batches. + y: np.array, or csr_matrix + The corresponding goldstandard of the entire dataset to split into batches. May be None if 'predict' = False + batch_size: int + Size of the resulting batches. + shuffle: bool + Whether to shuffle the dataset before splitting into batches (should be done at training time, shouldn't be done at prediction time). + predict: bool + Whether we are in prediction mode, i.e., there is no goldstandard. 
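+
+    Illustrative usage (assumes X_train and y_train are csr_matrix instances)::
+
+        generator = BatchGenerator(X_train, y_train, batch_size=256, shuffle=True, predict=False)
+        X_batch, y_batch = generator._batch_generator()  # dense arrays for one mini-batch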
+ """ + def __init__(self, X, y, batch_size, shuffle, predict): + self.X = X + self.y = y + self.number_of_batches = np.ceil(X.shape[0] / batch_size) + self.counter = 0 + self.sample_index = np.arange(X.shape[0]) + self.batch_size = batch_size + self.predict = predict + self.shuffle = shuffle + if shuffle: + np.random.shuffle(self.sample_index) + + + def _batch_generator(self): + + batch_index = self.sample_index[self.batch_size * self.counter:self.batch_size * (self.counter + 1)] + + X_batch = self.X[batch_index, :] + if sps.issparse(X_batch): + X_batch = X_batch.toarray() + + if not self.predict: + y_batch = self.y[batch_index].toarray() + self.counter += 1 + if self.counter == self.number_of_batches: + if self.shuffle: + np.random.shuffle(self.sample_index) + self.counter = 0 + if not self.predict: + return X_batch, y_batch + else: + return X_batch + + +class MultiLabelSKFlow(BaseEstimator): + """ + This is a wrapper class for TensorFlow, so it adheres to the fit/predict naming conventions of sk-learn. + This class handles the output layer, mini-batch learning, early stopping, threshold optimization on the validation set, and the neural metalabeler. + + The concrete TensorFlow model up to the last hidden layer can be specified in terms of the 'get_model' function. + This function in turn has to accept the dataset X (np.array, or csr_matrix), and the goldstandard y (csr_matrix). + Moreover, get_model() is expected to return the following components: + + x_tensor: tf.placeholder + Used to pass input data to the model at training and test time. + y_tensor: tf.placeholder + Used by to pass the ground truth to the model during training. + last_layer: tf.Tensor + The TensorFlow computation graph from input layer to last hidden layer of the implemented neural network. + params_fit: dictionary + Parameters to be added to the feed dictionary for training (e.g., keep_probability_placeholder -> 0.5) + params_predict: dictionary + Parameters to be added to the feed dictionary + at prediction time (e.g., keep_probability_placeholder -> 1.0) + initializer_operations: list of (tf.Tensor, dictionary) + A list of + pairs consisting of operations for initializing variables (e.g., + embedding tables) before training starts, and the feed dictionary with data to execute + the initialize operation. + + Moreover, training can be controlled by the following parameters: + + Parameters + ---------- + batch_size: int, default = 5 + Batch size to use during training and at prediction time. + num_epochs: int, default = 10 + Number of iterations over the dataset during training. + get_model: function, default = mlp_base() + The function that returns the underlying neural network up to the last hidden layer. See above description. + threshold: float, default = 0.2 + Fixed threshold to use if "optimize_threshold" = False, or starting threshold when "optimize_threshold" = True. + learning_rate: float, default = 0.1 + Initial learning rate to use for Adam. + patience, int, default = 5 + Number of non-improving evaluations on the validation set before terminating training. + validation_metric, function true_values, predicted_values -> float, default = f1_score + The metric that is used for evaluating prediction on the validation set. + optimize_threshold, boolean, default = True + Determines whether the threshold is optimized on a validation set. + threshold_window, array-like of float, default = np.linspace(-0.03, 0.03, num=7) + An array of floats that are interpreted as offset from the current threshold value. 
When optimizing the threshold, + each of these offsets is added to the current threshold and the validation performance is assessed. Afterwards, the + threshold is set to the value that has yielded the best score. + tf_model_path, str, default = ".tmp_best_models" + A path to the folder where the weights of the best model are saved, so it can be loaded at prediction time. + num_steps_before_validation, int, default = None + Determines the number of batches between two performance evaluations on the validation set. If set to None, this number is determined from the size of + the training set, i.e., it is set to one epoch. + hidden_activation_function, TensorFlow operation, default = tf.nn.relu + The activation function to apply after the bottleneck layer. + bottleneck_layers, list of int, default = None + As many layers as there are elements in this list are injected before the output layer. Element i specifies the number of units + in bottleneck layer i. + hidden_keep_prob, float, default = 0.5 + Specifies the keep probability of dropout to apply after each bottleneck layer. + gpu_memory_fraction, float, default = 1. + Specifies how much of the RAM of each available GPU TensorFlow may reserve. + meta_labeler_phi, str, default = None + Determines which 'phi' function from the definition of Neural MetaLabeler we use: "content", "score", or None. If none is used, MetaLabeler is not + used at all. If "content" is used, the prediction is based on the output of the last hidden layer from the underlying neural network (given by get_model). + If "score" is used, the prediction is based on the probabilities given by the output layer. + meta_labeler_alpha, float, default = 0.1 + The label-classification objective is weighted by (1 - alpha), and the objective of predicting the number of labels is weighted by alpha. + meta_labeler_min_labels, int, default = 1 + Specifies the smallest possible number of labels that can be predicted by Neural MetaLabeler. + meta_labeler_max_labels, int, default = None + Specifies the largest possible number of labels that can be predicted by Neural MetaLabeler. If set to None, the maximum number of labels is determined from + the training set. 
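+
+    Illustrative usage (the hyperparameter values below are assumptions, not recommendations)::
+
+        clf = MultiLabelSKFlow(get_model=mlp_base(), batch_size=256, num_epochs=20, learning_rate=0.001)
+        # optional early stopping: rows of X/y from this index onwards are treated as validation data
+        clf.validation_data_position = num_training_samples
+        clf.fit(X, y)                  # X/y = training rows followed by validation rows
+        y_pred = clf.predict(X_test)   # csr_matrix of binary label decisions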
+ """ + + def __init__(self, batch_size = 5, num_epochs = 10, get_model = mlp_base(), threshold = 0.2, learning_rate = 0.1, patience = 5, + validation_metric = lambda y1, y2 : f1_score(y1, y2, average = "samples"), + optimize_threshold = True, + threshold_window = np.linspace(-0.03, 0.03, num=7), + tf_model_path = ".tmp_best_models", + num_steps_before_validation = None, + hidden_activation_function = tf.nn.relu, + bottleneck_layers = None, + hidden_keep_prob = 0.5, + gpu_memory_fraction = 1., + meta_labeler_phi = None, + meta_labeler_alpha = 0.1, + meta_labeler_min_labels = 1, + meta_labeler_max_labels = None): + """ + + """ + + self.get_model = get_model + + # enable early stopping on validation set + self.validation_data_position = None + self.num_steps_before_validation = num_steps_before_validation + + # configurations for bottleneck layers + self.hidden_activation_function = hidden_activation_function + self.bottleneck_layers = bottleneck_layers + self.hidden_keep_prob = hidden_keep_prob + + # configuration for meta-labeler + self.meta_labeler_phi = meta_labeler_phi + self.meta_labeler_alpha = meta_labeler_alpha + self.num_label_binarizer = None + self.meta_labeler_max_labels = meta_labeler_max_labels + self.meta_labeler_min_labels = meta_labeler_min_labels + + # used by this class + self.validation_metric = validation_metric + self.optimize_threshold = optimize_threshold + self.threshold_window = threshold_window + self.patience = patience + self.batch_size = batch_size + self.num_epochs = num_epochs + self.threshold = threshold + if learning_rate is None: + self.learning_rate = self.batch_size / 512 * 0.01 + else: + self.learning_rate = learning_rate + + # path to save the tensorflow model to + self.TF_MODEL_PATH = tf_model_path + self._save_model_path = self._get_save_model_path() + + # determine how much of gpu to use + self.gpu_memory_fraction = gpu_memory_fraction + + def _get_save_model_path(self): + TMP_FOLDER = self.TF_MODEL_PATH + if not os.path.exists(TMP_FOLDER): + os.makedirs(TMP_FOLDER) + return TMP_FOLDER + "/best-model-" + self.get_model.__name__ + str(datetime.now()) + + def _calc_num_steps(self, X): + return int(np.ceil(X.shape[0] / self.batch_size)) + + + def _predict_batch(self, X_batch): + feed_dict = {self.x_tensor: X_batch} + feed_dict.update(self.params_predict) + + if self.meta_labeler_phi is None: + predictions = self.session.run(self.predictions, feed_dict = feed_dict) + else: + predictions = self.session.run([self.predictions, self.meta_labeler_prediction], feed_dict = feed_dict) + + return predictions + + def _make_binary_decision(self, predictions): + if self.meta_labeler_phi is None: + y_pred = predictions > self.threshold + else: + predictions, meta_labeler_predictions = predictions + max_probability_cols = np.argmax(meta_labeler_predictions, axis = 1) + max_probability_indices = tuple(np.indices([meta_labeler_predictions.shape[0]]))+(max_probability_cols,) + meta_labeler_predictions = np.zeros_like(meta_labeler_predictions) + meta_labeler_predictions[max_probability_indices] = 1 + meta_labeler_predictions = self.num_label_binarizer.inverse_transform(meta_labeler_predictions, 0) + y_pred = np.zeros_like(predictions) + for i in range(predictions.shape[0]): + num_labels_for_sample = meta_labeler_predictions[i] + top_indices = (-predictions[i,:]).argsort()[:num_labels_for_sample] + y_pred[i,top_indices] = 1 + + return csr_matrix(y_pred) + + def _compute_validation_score(self, session, X_val_batch, y_val_batch): + + feed_dict = {self.x_tensor: X_val_batch} 
+ feed_dict.update(self.params_predict) + + if self.validation_metric == "val_loss": + return session.run(self.loss, feed_dict = feed_dict) + + elif callable(self.validation_metric): + predictions = self._predict_batch(X_val_batch) + y_pred = self._make_binary_decision(predictions) + if self.optimize_threshold: + return self.validation_metric(y_val_batch, y_pred), predictions + else: + return self.validation_metric(y_val_batch, y_pred) + + def _print_progress(self, epoch, batch_i, steps_per_epoch, avg_validation_score, best_validation_score, total_loss, meta_loss, label_loss): + + progress_string = 'Epoch {:>2}/{:>2}, Batch {:>2}/{:>2}, Loss: {:0.4f}, Validation-Score: {:0.4f}, Best Validation-Score: {:0.4f}' + format_parameters = [epoch + 1, self.num_epochs, batch_i + 1, steps_per_epoch, + total_loss, avg_validation_score, best_validation_score] + if self.meta_labeler_phi is None: + progress_string += ', Threshold: {:0.2f}' + format_parameters.append(self.threshold) + + else: + progress_string += ', Label-Loss: {:0.4f}, Meta-Loss: {:0.4f}' + format_parameters.extend([label_loss, meta_loss]) + + progress_string = progress_string.format(*format_parameters) + print(progress_string, end='\r') + + def _num_labels_discrete(self, y, min_number_labels = 1, max_number_labels = None): + """ + Counts for each row in 'y' how many of the columns are set to 1. Outputs the result in turn as a binary indicator matrix where + the columns 0, ..., m correspond to 'min_number_labels', 'min_number_labels' + 1, ..., 'max_number_labels'. + + Parameters + ---------- + y: (sparse) numpy array of shape [n_samples, n_classes] + An indicator matrix denoting which classes are assigned to a sample (multiple columns per row may be 1) + min_number_labels: int, default=1 + Minimum number of labels each sample has to have. If a sample has less than 'min_number_labels' assigned, + the corresponding output is set to 'min_number_labels'. + max_number_labels: int, default=None + Maximum number of labels each sample has to have. If a sample has more than 'min_number_labels' assigned, + the corresponding output is set to 'max_number_labels'. If 'max_number_labels' is None, it is set to the max number found + in y. 
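+        For example (illustrative): with min_number_labels=1 and rows of y containing 0, 2 and 4
+        assigned labels, the clipped label counts are [1, 2, 4], which the internal LabelBinarizer
+        then returns as a one-hot indicator matrix.
+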
+ Returns + --------- + num_samples_y: (sparse) numpy array of shape [n_samples, max_number_samples - min_number_samples + 1] + """ + + num_samples_y = np.array(np.sum(y, axis = 1)) + num_samples_y = num_samples_y.reshape(-1) + num_samples_y[num_samples_y < min_number_labels] = min_number_labels + + if max_number_labels is None: + max_number_labels = np.max(num_samples_y) + + num_samples_y[num_samples_y > max_number_labels] = max_number_labels + + # 'fit' method calls this + if self.num_label_binarizer is None: + self.num_label_binarizer = LabelBinarizer() + self.num_label_binarizer.fit(num_samples_y) + + indicator_matrix_num_labels = self.num_label_binarizer.transform(num_samples_y) + return indicator_matrix_num_labels + + + def fit(self, X, y): + self.y = y + + val_pos = self.validation_data_position + + if val_pos is not None: + X_train, y_train, X_val, y_val = X[:val_pos, :], y[:val_pos,:], X[val_pos:, :], y[val_pos:, :] + + validation_batch_generator = BatchGenerator(X_val, y_val, self.batch_size, False, False) + validation_predictions = self._calc_num_steps(X_val) + steps_per_epoch = self._calc_num_steps(X_train) + + # determine after how many batches to perform validation + num_steps_before_validation = self.num_steps_before_validation + if self.num_steps_before_validation is None: + num_steps_before_validation = steps_per_epoch + num_steps_before_validation = int(min(steps_per_epoch, num_steps_before_validation)) + else: + steps_per_epoch = self._calc_num_steps(X) + X_train = X + y_train = y + + # Remove previous weights, bias, inputs, etc.. + tf.reset_default_graph() + tf.set_random_seed(1337) + + # get_model has to return a + self.x_tensor, self.y_tensor, self.last_layer, self.params_fit, self.params_predict, initializer_operations = self.get_model(X, y) + + # add bottleneck layer + if self.bottleneck_layers is not None: + bottleneck_dropout_tensor = tf.placeholder(tf.float32, name = "bottleneck_dropout") + self.params_fit.update({bottleneck_dropout_tensor : self.hidden_keep_prob}) + self.params_predict.update({bottleneck_dropout_tensor : 1}) + for units in self.bottleneck_layers: + self.last_layer = tf.contrib.layers.fully_connected(self.last_layer, units, activation_fn = self.hidden_activation_function) + self.last_layer = tf.nn.dropout(self.last_layer, bottleneck_dropout_tensor) + + + # Name logits Tensor, so that is can be loaded from disk after training + #logits = tf.identity(logits, name='logits') + logits = tf.contrib.layers.linear(self.last_layer, + num_outputs=y.shape[1]) + + + # Loss and Optimizer + losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.y_tensor) + loss = tf.reduce_sum(losses, axis = 1) + self.label_loss = tf.reduce_mean(loss, axis = 0) + + # prediction + self.predictions = tf.sigmoid(logits) + + if self.meta_labeler_phi is not None: + + # compute target of meta labeler + y_num_labels = self._num_labels_discrete(y_train, min_number_labels=self.meta_labeler_min_labels,max_number_labels=self.meta_labeler_max_labels) + y_num_labels_tensor = tf.placeholder(tf.float32, shape=(None, y_num_labels.shape[1]), name = "y_num_labels") + + # compute logits of meta labeler + if self.meta_labeler_phi == "content": + meta_logits = tf.contrib.layers.linear(self.last_layer, num_outputs=y_num_labels.shape[1]) + elif self.meta_labeler_phi == "score": + meta_logits = tf.contrib.layers.linear(self.predictions, num_outputs=y_num_labels.shape[1]) + + # compute loss of meta labeler + meta_labeler_loss = tf.nn.softmax_cross_entropy_with_logits(labels = 
y_num_labels_tensor, logits = meta_logits) + self.meta_labeler_loss = tf.reduce_mean(meta_labeler_loss, axis = 0) + + # compute prediction of meta labeler + self.meta_labeler_prediction = tf.nn.softmax(meta_logits) + + # add meta labeler loss to labeling loss + self.loss = (1 - self.meta_labeler_alpha) * self.label_loss + self.meta_labeler_alpha * self.meta_labeler_loss + else: + self.loss = self.label_loss + + # optimize + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss) + + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_memory_fraction) + session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) + self.session = session + # Initializing the variables + session.run(tf.global_variables_initializer()) + for (init_op, init_op_feed_dict) in initializer_operations: + session.run(init_op, feed_dict = init_op_feed_dict) + + batch_generator = BatchGenerator(X_train, y_train, self.batch_size, True, False) + # Training cycle + objective = 1 if self.validation_metric == "val_loss" else -1 + avg_validation_score = math.inf * objective + best_validation_score = math.inf * objective + epochs_of_no_improvement = 0 + most_consecutive_epochs_with_no_improvement = 0 + batches_counter = 0 + epoch = 0 + stop_early = False + while epoch < self.num_epochs and not stop_early: + + if val_pos is not None and epochs_of_no_improvement == self.patience: + break + + # Loop over all batches + for batch_i in range(steps_per_epoch): + X_batch, y_batch = batch_generator._batch_generator() + feed_dict = {self.x_tensor: X_batch, self.y_tensor: y_batch} + feed_dict.update(self.params_fit) + + if self.meta_labeler_phi is not None: + feed_dict.update({y_num_labels_tensor : self._num_labels_discrete(y_batch)}) + + session.run(optimizer, feed_dict = feed_dict) + + # overwrite parameter values for prediction step + feed_dict.update(self.params_predict) + + # compute losses to track progress + if self.meta_labeler_phi is not None: + total_loss, label_loss, meta_loss = session.run([self.loss, self.label_loss, self.meta_labeler_loss], feed_dict = feed_dict) + else: + total_loss = session.run(self.loss, feed_dict = feed_dict) + label_loss, meta_loss = None, None + + batches_counter += 1 + is_last_epoch = epoch == self.num_epochs - 1 + is_last_batch_in_epoch = batch_i == steps_per_epoch - 1 + # calculate validation loss at end of epoch if early stopping is on + if val_pos is not None and (batches_counter == num_steps_before_validation + or (is_last_epoch and is_last_batch_in_epoch)): + + batches_counter = 0 + + validation_scores = [] + weights = [] + + # save predictions so we can optimize threshold later + val_predictions = np.zeros((X_val.shape[0], self.y.shape[1])) + for i in range(validation_predictions): + X_val_batch, y_val_batch = validation_batch_generator._batch_generator() + weights.append(X_val_batch.shape[0]) + + if self.optimize_threshold: + batch_val_score, val_predictions[i * self.batch_size:(i+1) * self.batch_size, :] = self._compute_validation_score(session, X_val_batch, y_val_batch) + else: + batch_val_score = self._compute_validation_score(session, X_val_batch, y_val_batch) + validation_scores.append(batch_val_score) + avg_validation_score = np.average(np.array(validation_scores), weights = np.array(weights)) + + if self.optimize_threshold: + best_score = -1 * math.inf + best_threshold = self.threshold + for t_diff in self.threshold_window: + t = 
self.threshold + t_diff + score = self.validation_metric(y_val, csr_matrix(val_predictions > t)) + if score > best_score: + best_threshold = t + best_score = score + + is_better_score = avg_validation_score < best_validation_score if objective == 1 else avg_validation_score > best_validation_score + if is_better_score: + # save model + # Save model for prediction step + best_validation_score = avg_validation_score + saver = tf.train.Saver() + saver.save(session, self._save_model_path) + + if most_consecutive_epochs_with_no_improvement < epochs_of_no_improvement: + most_consecutive_epochs_with_no_improvement = epochs_of_no_improvement + epochs_of_no_improvement = 0 + + # save the threshold at best model, too. + if self.optimize_threshold: + self.threshold = best_threshold + else: + epochs_of_no_improvement += 1 + if epochs_of_no_improvement > self.patience: + print("No improvement in validation loss for", self.patience, "epochs. Stopping early.") + stop_early = True + break + + # print progress + self._print_progress(epoch, batch_i, steps_per_epoch, avg_validation_score, best_validation_score, total_loss, meta_loss, label_loss) + + epoch += 1 + + print('') + + print("Training of TensorFlow model finished!") + print("Longest sequence of epochs of no improvement:", most_consecutive_epochs_with_no_improvement) + + + def predict(self, X): + + session = self.session + #loaded_graph = tf.Graph() + if self.validation_data_position: + # Load model + loader = tf.train.import_meta_graph(self._save_model_path + '.meta') + loader.restore(self.session, self._save_model_path) + + prediction = np.zeros((X.shape[0], self.y.shape[1])) + batch_generator = BatchGenerator(X, None, self.batch_size, False, True) + prediction_steps = self._calc_num_steps(X) + for i in range(prediction_steps): + X_batch = batch_generator._batch_generator() + preds = self._predict_batch(X_batch) + binary_decided_preds = self._make_binary_decision(preds) + prediction[i * self.batch_size:(i+1) * self.batch_size, :] = binary_decided_preds.todense() + + result = csr_matrix(prediction) + + # close the session, since no longer needed + session.close() + return result + + diff --git a/Code/lucid_ml/default_searchspace b/Code/lucid_ml/default_searchspace new file mode 100644 index 0000000..afc565e --- /dev/null +++ b/Code/lucid_ml/default_searchspace @@ -0,0 +1,25 @@ +# the structure should be like this: +# +# for random search, specify the following format: +# ,,, +# ... +# ,,, +# example: +#learning_rate,-9.21,-2.3,loguniform +#dropout,0.05,0.95,uniform +# +# for bayesian optimization, specify in the following format: +# ,,,, +# ... +# ,,,, +# example: +learning_rate,0.0001,1,0.1,0.01,0.111 +dropout,0.05,0.95,0.05,0.5,0.555 +# +# for grid search, specify in the following format: +# ,,,.., +# ... 
+# ,,,..., +# example: +#learning_rate,float,0.01,0.005,0.001 +#dropout,float,0.5,0.05,0.95 \ No newline at end of file diff --git a/Code/lucid_ml/file_paths.json b/Code/lucid_ml/file_paths.json index b7d43ed..09d27c4 100644 --- a/Code/lucid_ml/file_paths.json +++ b/Code/lucid_ml/file_paths.json @@ -1,12 +1,12 @@ { - "example-titles": { - "X": "../../Resources/example/example-titles.tsv", - "y": "../../Resources/example/example-goldstandard.tsv", - "thes": "../../Resources/example/stw.json" + "econbiz": { + "format" : "combined", + "X" : "../../Resources/econbiz.csv", + "label_delimiter" : "\t" }, - "example-fulltext": { - "X": "../../Resources/example/example-fulltext", - "y": "../../Resources/example/example-goldstandard.tsv", - "thes": "../../Resources/example/stw.json" + "pubmed": { + "format" : "combined", + "X" : "../../Resources/pubmed.csv", + "label_delimiter" : "\t" } } diff --git a/Code/lucid_ml/run.py b/Code/lucid_ml/run.py index a1f9eff..c2dd1a3 100644 --- a/Code/lucid_ml/run.py +++ b/Code/lucid_ml/run.py @@ -7,6 +7,7 @@ from classifying.neural_net import MLP, ThresholdingPredictor from classifying.stack_lin_reg import LinRegStack +from rdflib.plugins.parsers.ntriples import validate os.environ['OMP_NUM_THREADS'] = '1' # For parallelization use n_jobs, this gives more control. import numpy as np @@ -16,10 +17,11 @@ from utils.processify import processify +from itertools import product + warnings.filterwarnings("ignore", category=UserWarning) -from sklearn import model_selection -from sklearn.model_selection import ShuffleSplit +from sklearn.model_selection import KFold, ShuffleSplit # from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import TfidfTransformer @@ -30,35 +32,47 @@ from sklearn.pipeline import FeatureUnion, make_pipeline, Pipeline from sklearn.preprocessing import MultiLabelBinarizer from sklearn.svm import LinearSVC +import scipy.sparse as sps + +# imports for hyperparameter optimization +from bayes_opt import BayesianOptimization +from hyperopt import fmin, rand, hp +from sklearn.gaussian_process.kernels import Matern from classifying.br_kneighbor_classifier import BRKNeighborsClassifier from classifying.kneighbour_l2r_classifier import KNeighborsL2RClassifier from classifying.nearest_neighbor import NearestNeighbor from classifying.rocchioclassifier import RocchioClassifier from classifying.stacked_classifier import ClassifierStack +from classifying.tensorflow_models import MultiLabelSKFlow, mlp_base, mlp_soph, cnn, lstm from utils.Extractor import load_dataset from utils.metrics import hierarchical_f_measure, f1_per_sample -from utils.nltk_normalization import NltkNormalizer, word_regexp +from utils.nltk_normalization import NltkNormalizer, word_regexp, character_regexp from utils.persister import Persister from weighting.SpreadingActivation import SpreadingActivation, BinarySA, OneHopActivation from weighting.synset_analysis import SynsetAnalyzer from weighting.bm25transformer import BM25Transformer from weighting.concept_analysis import ConceptAnalyzer from weighting.graph_score_vectorizer import GraphVectorizer +from utils.text_encoding import TextEncoder +### SET LOGGING logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') -def run(options): +def _build_features(options): DATA_PATHS = json.load(options.key_file) VERBOSE = options.verbose persister = Persister(DATA_PATHS, options) if options.persist and persister.is_saved(): X, Y, 
tr = persister.read() - if VERBOSE: print("Y = " + str(Y.shape)) else: # --- LOAD DATA --- - X_raw, Y_raw, tr = load_dataset(DATA_PATHS, options.data_key, options.fulltext) + fold_list = None + if options.fixed_folds: + X_raw, Y_raw, tr, fold_list = load_dataset(DATA_PATHS, options.data_key, options.fulltext, fixed_folds=True) + else: + X_raw, Y_raw, tr, _ = load_dataset(DATA_PATHS, options.data_key, options.fulltext, fixed_folds=False) if options.toy_size < 1: if VERBOSE: print("Just toying with %d%% of the data." % (options.toy_size * 100)) zipped = list(zip(X_raw, Y_raw)) @@ -71,18 +85,38 @@ def run(options): mlb = MultiLabelBinarizer(sparse_output=True, classes=[i[1] for i in sorted( tr.index_nodename.items())] if options.hierarch_f1 else None) Y = mlb.fit_transform(Y_raw) - if VERBOSE: print("Y = " + str(Y.shape)) # --- EXTRACT FEATURES --- input_format = 'filename' if options.fulltext else 'content' - concept_analyzer = SynsetAnalyzer().analyze if options.synsets \ - else ConceptAnalyzer(tr.thesaurus, input=input_format, persist=options.persist and options.concepts, - persist_dir=options.persist_to, repersist=options.repersist, - file_path=DATA_PATHS[options.data_key]['X']).analyze - terms = CountVectorizer(input=input_format, stop_words='english', binary=options.binary, - token_pattern=word_regexp) - concepts = CountVectorizer(input=input_format, analyzer=concept_analyzer, binary=options.binary, - vocabulary=tr.nodename_index if not options.synsets else None) + if options.concepts: + if tr is None: + raise ValueError("Unable to extract concepts, since no thesaurus is given!") + concept_analyzer = SynsetAnalyzer().analyze if options.synsets \ + else ConceptAnalyzer(tr.thesaurus, input=input_format, persist=options.persist and options.concepts, + persist_dir=options.persist_to, repersist=options.repersist, + file_path=DATA_PATHS[options.data_key]['X']).analyze + concepts = CountVectorizer(input=input_format, analyzer=concept_analyzer, binary=options.binary, + vocabulary=tr.nodename_index if not options.synsets else None, + ngram_range=(1, options.ngram_limit)) + + # pick max_features from each ngram-group + if options.max_features is not None and options.ngram_limit > 1: + ngram_vectorizers = [] + for i in range(1, options.ngram_limit + 1): + i_grams = CountVectorizer(input=input_format, stop_words='english', binary=options.binary, + token_pattern=word_regexp, max_features=options.max_features, + ngram_range=(i, i)) + ngram_vectorizers.append(i_grams) + terms = FeatureUnion([(str(i) + "_gram", t) for i, t in enumerate(ngram_vectorizers)]) + else: + terms = CountVectorizer(input=input_format, stop_words='english', binary=options.binary, + token_pattern=word_regexp, max_features=options.max_features, + ngram_range=(1, options.ngram_limit)) + if options.charngrams: + character_ngrams = CountVectorizer(input=input_format, binary=options.binary, + token_pattern=character_regexp, max_features=options.max_features, + ngram_range=(1, options.char_ngram_limit), + analyzer = 'char_wb') if options.hierarchical: hierarchy = tr.nx_graph @@ -109,15 +143,24 @@ def run(options): activation = SpreadingActivation(tr.nx_graph, firing_threshold=1.0, decay=0.25, weighting=None) concepts = make_pipeline(concepts, activation) + features = [] if options.graph_scoring_method: - extractor = GraphVectorizer(method=options.graph_scoring_method, analyzer=concept_analyzer - if options.concepts else NltkNormalizer().split_and_normalize) - elif options.terms and (options.concepts or options.synsets): - extractor = 
FeatureUnion([("terms", terms), ("concepts", concepts)]) - elif options.terms: - extractor = terms - else: - extractor = concepts + features.append(("graph_vectorizer", GraphVectorizer(method=options.graph_scoring_method, analyzer=concept_analyzer + if options.concepts else NltkNormalizer().split_and_normalize))) + if options.terms: + features.append(("terms", terms)) + if options.concepts: + features.append(("concets", concepts)) + if options.charngrams: + features.append(("char_ngrams", character_ngrams)) + if options.onehot: + features.append(("onehot", TextEncoder(input_format = "filename" if options.fulltext else "content", + max_words=options.max_features, pretrained = options.pretrained_embeddings, + pad_special_symbol=options.pad_special_symbol))) + if len(features) == 0: + raise ValueError("No feature representation specified!") + + extractor = FeatureUnion(features) if VERBOSE: print("Extracting features...") if VERBOSE > 1: start_ef = default_timer() @@ -125,15 +168,20 @@ def run(options): if VERBOSE > 1: print(default_timer() - start_ef) if options.persist: persister.persist(X, Y, tr) + + return X, Y, extractor, mlb, fold_list, X_raw, Y_raw, tr +def _print_feature_info(X, options): + VERBOSE = options.verbose if VERBOSE: - print("X = " + repr(X)) - print("Vocabulary size: {}".format(X.shape[1])) + print("Feature size: {}".format(X.shape[1])) print("Number of documents: {}".format(X.shape[0])) - print("Mean distinct words per document: {}".format(X.count_nonzero() / - X.shape[0])) - words = X.sum(axis=1) - print("Mean word count per document: {} ({})".format(words.mean(), words.std())) + # these printouts only make sense if we have BoW representation + if sps.issparse(X): + print("Mean distinct words per document: {}".format(X.count_nonzero() / + X.shape[0])) + words = X.sum(axis=1) + print("Mean word count per document: {} ({})".format(words.mean(), words.std())) if VERBOSE > 1: X_tmp = X.todense() @@ -152,8 +200,22 @@ def run(options): # n_iter = np.ceil(10**6 / (X.shape[0] * 0.9)) # print("Dynamic n_iter = %d" % n_iter) - - + +def _print_label_info(Y, VERBOSE): + if VERBOSE: + print("Y = " + str(Y.shape)) + y_sum = Y.sum(axis = 0) + for i in range(1, 5): + print("Number of labels assigned more than", i, "times:" , np.sum(y_sum > i)) + + # compute avg number of labels per document + sum_of_labels_per_document = Y.sum(axis = 1) + print("Average number of labels per document:" , np.mean(sum_of_labels_per_document)) + + # compute avg number of documents per label + print("Average number of documents per label:" , np.mean(y_sum)) + +def _check_interactive(options, X, Y, extractor, mlb): if options.interactive: print("Please wait...") clf = create_classifier(options, Y.shape[1]) # --- INTERACTIVE MODE --- @@ -171,26 +233,103 @@ def run(options): exit(1) exit(0) - if VERBOSE: print("Performing %d-fold cross-validation..." % (options.folds if options.cross_validation else 1)) +def _build_folds(options, fold_list): + validation_set_indices = None + + if options.cross_validation: + kf = KFold(n_splits=options.folds, shuffle=True) + elif options.fixed_folds: + fixed_folds = [] + + # TODO: we assume 10 normal folds and 1 folds with extra samples. 
need to generalize + basic_folds = range(10) + + # we assume the extra data to be in the last fold + # TODO: currently we assume 10 folds (+1 extra) + extra_data = [index for index,x in enumerate(fold_list) if x == 10] + + validation_set_indices = [] + for i in range(options.folds): + + training_fold = [index for index,x in enumerate(fold_list) if x in basic_folds and x != i] + + if options.validation_size > 0: + # separate validation from training set here, and rejoin later if appropriate + num_validation_samples = int(len(training_fold) * options.validation_size) + validation_set_indices.append(training_fold[:num_validation_samples]) + training_fold = training_fold[num_validation_samples:] + + # add more training data from extra samples + if options.extra_samples_factor > 1: + num_extra_samples = int(min((options.extra_samples_factor - 1) * len(training_fold), len(extra_data))) + training_fold += extra_data[:num_extra_samples] + + test_fold = [index for index,x in enumerate(fold_list) if x == i] + fixed_folds.append((training_fold, test_fold)) + + # helper class to conform sklearn's model_selection structure + class FixedFoldsGenerator(): + def split(self, X): + return fixed_folds + + kf = FixedFoldsGenerator() + + else: + kf = ShuffleSplit(test_size=options.test_size, n_splits = 1) + + return kf, validation_set_indices +def _run_experiment(X, Y, kf, validation_set_indices, mlb, X_raw, Y_raw, tr, options): + VERBOSE = options.verbose + scores = defaultdict(list) if options.plot: all_f1s = [] # --- CROSS-VALIDATION --- scores = defaultdict(list) if options.cross_validation: - kf = model_selection.KFold(X.shape[0], n_folds=options.folds, shuffle=True) + kf = KFold(X.shape[0], n_folds=options.folds, shuffle=True) else: kf = ShuffleSplit(X.shape[0], test_size=options.test_size, n_iter=1) for train, test in kf: if VERBOSE: print("=" * 80) X_train, X_test, Y_train, Y_test = X[train], X[test], Y[train], Y[test] + + clf = create_classifier(options, Y_train.shape[1]) # --- INTERACTIVE MODE --- + + # extract a validation set and inform the classifier where to find it + if options.validation_size > 0: + # if we don't have fixed folds, we may pick the validation set randomly + if options.cross_validation or options.one_fold: + train, val = next(ShuffleSplit(test_size=options.validation_size, n_splits = 1).split(X_train)) + X_train, X_val, Y_train, Y_val = X_train[train], X_train[val], Y_train[train], Y_train[val] + elif options.fixed_folds: + X_train = X_train + X_val = X[validation_set_indices[iteration]] + Y_val = Y[validation_set_indices[iteration]] + + # put validation data at the end of training data and tell classifier the position where they start, if it is able + _, estimator = clf.steps[-1] + if hasattr(estimator, 'validation_data_position'): + estimator.validation_data_position = X_train.shape[0] + else: + raise ValueError("Validation size given although the estimator has no 'validation_data_position' property!") + + if sps.issparse(X): + X_train = sps.vstack((X_train, X_val)) + else: + X_train = np.vstack((X_train, X_val)) + + if sps.issparse(Y): + Y_train = sps.vstack((Y_train, Y_val)) + else: + Y_train = np.vstack((Y_train, Y_val)) # mlp doesn't seem to like being stuck into a new process... 
- if options.debug or options.clf_key in {'mlp', 'mlpthr'}: - Y_pred, Y_train_pred = fit_predict(X_test, X_train, Y_train, options, tr) + if options.debug or options.clf_key in {'mlp', 'mlpthr', 'mlpsoph', "cnn", "mlpbase", "lstm"}: + Y_pred, Y_train_pred = fit_predict(X_test, X_train, Y_train, options, tr, clf) else: - Y_pred, Y_train_pred = fit_predict_new_process(X_test, X_train, Y_train, options, tr) + Y_pred, Y_train_pred = fit_predict_new_process(X_test, X_train, Y_train, options, tr, clf) if options.training_error: scores['train_f1_samples'].append(f1_score(Y_train, Y_train_pred, average='samples')) @@ -244,9 +383,166 @@ def run(options): return results +def _update_options(options, **parameters): + for param_name, param_value in parameters.items(): + print("In automatic optimization trying parameter:", param_name, "with value", param_value) + + try: + setattr(options, param_name, param_value) + except AttributeError: + print("Can't find parameter ", param_name, "so we'll not use it.") + continue + + return options + + +def _make_space(options): + + space = {} + inits = {} + with open(options.optimization_spaces) as optimization_file: + for line in optimization_file: + + # escape comments + if line.startswith("#"): + continue + + line = line.strip() + info = line.split(",") + param_name = info[0] + + if options.optimization == "random": + left_bound, right_bound = float(info[1]), float(info[2]) + + param_type = info[3] + + try: + param_type = getattr(hp, param_type) + except AttributeError: + print("hyperopt has no attribute", param_type) + continue + + space[param_name] = param_type(param_name, left_bound, right_bound) + elif options.optimization == "bayesian": + left_bound, right_bound = float(info[1]), float(info[2]) + + init_values = list(map(float, info[3:])) + num_init_vals = len(init_values) + inits[param_name] = init_values + space[param_name] = (left_bound, right_bound) + + elif options.optimization == "grid": + + param_type = info[1] + def get_cast_func(some_string_type): + cast_func = None + if some_string_type == "int": + cast_func = int + elif some_string_type == "float": + cast_func = float + elif some_string_type == "string": + cast_func = str + elif some_string_type == "bool": + cast_func = bool + return cast_func + + cast_func = get_cast_func(param_type) + if cast_func is None: + if param_type.startswith("list"): + # determine type in list + list_type = get_cast_func(param_type.split("-")[1]) + + # assume they are seperated by semicolon + def extract_items(list_string): + return [list_type(x) for x in list_string.split(";")] + + cast_func = extract_items + + + # all possible values + space[param_name] = list(map(cast_func, info[2:])) + + if options.optimization == "bayesian": + return space, inits, num_init_vals + else: + return space + +def _all_option_combinations(space): + + names = [name for name, _ in space.items()] + values = [values for _, values in space.items()] + + val_combinations = product(*values) + + combinations = [] + for combi in val_combinations: + new_param_dict = {} + for i, val in enumerate(combi): + new_param_dict[names[i]] = val + + combinations.append(new_param_dict) + + return combinations + +def run(options): + VERBOSE = options.verbose + + ### SET SEEDS FOR REPRODUCABILITY + np.random.seed(1337) + random.seed(1337) + ### + + # load dataset and build feature representation + X, Y, extractor, mlb, fold_list, X_raw, Y_raw, tr = _build_features(options) + _print_feature_info(X, options) + _print_label_info(Y, options) + + # go to interactive 
mode if on + _check_interactive(options, X, Y, extractor, mlb) + + if VERBOSE: print("Performing %d-fold cross-validation..." % (options.folds if options.cross_validation else 1)) + + # prepare validation over folds + kf, validation_set_indices = _build_folds(options, fold_list) + + if options.optimization: + + def optimized_experiment(**parameters): + + current_options = _update_options(options, **parameters) + results = _run_experiment(X, Y, kf, validation_set_indices, mlb, X_raw, Y_raw, tr, current_options) + + # return the f1 score of the previous experiment + return results["f1_samples"][0] + + if options.optimization == "bayesian": + + gp_params = {"alpha": 1e-5, "kernel" : Matern(nu = 5 / 2)} + space, init_vals, num_init_vals = _make_space(options) + bayesian_optimizer = BayesianOptimization(optimized_experiment, space) + bayesian_optimizer.explore(init_vals) + bayesian_optimizer.maximize(n_iter=options.optimization_iterations - num_init_vals, + acq = 'ei', + **gp_params) + + elif options.optimization == "random": + + fmin(lambda parameters : optimized_experiment(**parameters), + _make_space(options), + algo=rand.suggest, + max_evals=options.optimization_iterations, + rstate = np.random.RandomState(1337)) + + elif options.optimization == "grid": + # perform grid-search by running every possible parameter combination + combinations = _all_option_combinations(_make_space(options)) + for combi in combinations: + optimized_experiment(**combi) + + else: + results = _run_experiment(X, Y, kf, validation_set_indices, mlb, X_raw, Y_raw, tr, options) -def fit_predict(X_test, X_train, Y_train, options, tr): - clf = create_classifier(options, Y_train.shape[1]) # --- INTERACTIVE MODE --- +def fit_predict(X_test, X_train, Y_train, options, tr, clf): if options.verbose: print("Fitting", X_train.shape[0], "samples...") clf.fit(X_train, Y_train) @@ -261,8 +557,8 @@ def fit_predict(X_test, X_train, Y_train, options, tr): return Y_pred, Y_pred_train @processify -def fit_predict_new_process(X_test, X_train, Y_train, options, tr): - return fit_predict(X_test, X_train, Y_train, options, tr) +def fit_predict_new_process(X_test, X_train, Y_train, options, tr, clf): + return fit_predict(X_test, X_train, Y_train, options, tr, clf) def create_classifier(options, num_concepts): # Learning 2 Rank algorithm name to ranklib identifier mapping @@ -285,7 +581,7 @@ def create_classifier(options, num_concepts): l2r_metric = options.l2r_metric + "@20", n_jobs = options.jobs, translation_probability = options.translation_prob) - mlp = MLP(verbose=options.verbose) + mlp = MLP(verbose=options.verbose, batch_size = options.batch_size, learning_rate = options.learning_rate, epochs = options.max_iterations) classifiers = { "nn": NearestNeighbor(use_lsh_forest=options.lshf), "brknna": BRKNeighborsClassifier(mode='a', n_neighbors=options.k, use_lsh_forest=options.lshf, @@ -306,7 +602,85 @@ def create_classifier(options, num_concepts): "rocchiodt": ClassifierStack(base_classifier=RocchioClassifier(metric = 'cosine'), n_jobs=options.jobs, n=options.k), "logregressdt": ClassifierStack(base_classifier=logregress, n_jobs=options.jobs, n=options.k), "mlp": mlp, - "nam": ThresholdingPredictor(MLP(verbose=options.verbose, final_activation='sigmoid'), alpha=options.alpha, stepsize=0.01, verbose=options.verbose), + "mlpbase" : MultiLabelSKFlow(batch_size = options.batch_size, + num_epochs=options.max_iterations, + learning_rate = options.learning_rate, + tf_model_path = options.tf_model_path, + optimize_threshold = 
options.optimize_threshold, + get_model = mlp_base(hidden_activation_function = options.hidden_activation_function), + patience = options.patience, + num_steps_before_validation = options.num_steps_before_validation, + bottleneck_layers = options.bottleneck_layers, + hidden_keep_prob = options.dropout, + gpu_memory_fraction = options.memory), + "mlpsoph" : MultiLabelSKFlow(batch_size = options.batch_size, + num_epochs=options.max_iterations, + learning_rate = options.learning_rate, + tf_model_path = options.tf_model_path, + optimize_threshold = options.optimize_threshold, + get_model = mlp_soph(options.dropout, options.embedding_size, + hidden_layers = options.hidden_layers, self_normalizing = options.snn, + standard_normal = options.standard_normal, + batch_norm = options.batch_norm, + hidden_activation_function = options.hidden_activation_function + ), + patience = options.patience, + num_steps_before_validation = options.num_steps_before_validation, + bottleneck_layers = options.bottleneck_layers, + hidden_keep_prob = options.dropout, + gpu_memory_fraction = options.memory, + meta_labeler_phi = options.meta_labeler_phi, + meta_labeler_alpha = options.meta_labeler_alpha, + meta_labeler_min_labels = options.meta_labeler_min_labels, + meta_labeler_max_labels = options.meta_labeler_max_labels), + "cnn": MultiLabelSKFlow(batch_size = options.batch_size, + num_epochs=options.max_iterations, + learning_rate = options.learning_rate, + tf_model_path = options.tf_model_path, + optimize_threshold = options.optimize_threshold, + patience = options.patience, + num_steps_before_validation = options.num_steps_before_validation, + get_model = cnn(options.dropout, options.embedding_size, + hidden_layers = options.hidden_layers, + pretrained_embeddings_path = options.pretrained_embeddings, + trainable_embeddings=options.trainable_embeddings, + dynamic_max_pooling_p=options.dynamic_max_pooling_p, + window_sizes = options.window_sizes, + num_filters = options.num_filters), + bottleneck_layers = options.bottleneck_layers, + hidden_keep_prob = options.dropout, + gpu_memory_fraction = options.memory, + meta_labeler_phi = options.meta_labeler_phi, + meta_labeler_alpha = options.meta_labeler_alpha, + meta_labeler_min_labels = options.meta_labeler_min_labels, + meta_labeler_max_labels = options.meta_labeler_max_labels), + "lstm": MultiLabelSKFlow(batch_size = options.batch_size, + num_epochs=options.max_iterations, + learning_rate = options.learning_rate, + tf_model_path = options.tf_model_path, + optimize_threshold = options.optimize_threshold, + patience = options.patience, + num_steps_before_validation = options.num_steps_before_validation, + get_model = lstm(options.dropout, options.embedding_size, + hidden_layers = options.hidden_layers, + pretrained_embeddings_path = options.pretrained_embeddings, + trainable_embeddings = options.trainable_embeddings, + variational_recurrent_dropout = options.variational_recurrent_dropout, + bidirectional = options.bidirectional, + aggregate_output = options.aggregate_output, + iterate_until_maxlength = options.iterate_until_maxlength, + num_last_outputs = options.pad_special_symbol), + bottleneck_layers = options.bottleneck_layers, + hidden_keep_prob = options.dropout, + gpu_memory_fraction = options.memory, + meta_labeler_phi = options.meta_labeler_phi, + meta_labeler_alpha = options.meta_labeler_alpha, + meta_labeler_min_labels = options.meta_labeler_min_labels, + meta_labeler_max_labels = options.meta_labeler_max_labels), + "nam": 
ThresholdingPredictor(MLP(verbose=options.verbose, final_activation='sigmoid', batch_size = options.batch_size, + learning_rate = options.learning_rate, + epochs = options.max_iterations), + alpha=options.alpha, stepsize=0.01, verbose=options.verbose), "mlpthr": LinRegStack(mlp, verbose=options.verbose), "mlpdt" : ClassifierStack(base_classifier=mlp, n_jobs=options.jobs, n=options.k) } @@ -315,7 +689,7 @@ def create_classifier(options, num_concepts): if options.bm25: trf = BM25Transformer(sublinear_tf=True if options.lsa else False, use_idf=options.idf, norm=norm, bm25_tf=True, use_bm25idf=True) - else: + elif options.terms or options.concepts: trf = TfidfTransformer(sublinear_tf=True if options.lsa else False, use_idf=options.idf, norm=norm) # Pipeline with final estimator ## @@ -326,9 +700,10 @@ def create_classifier(options, num_concepts): # svd = TruncatedSVD(n_components=options.lsa) # lsa = make_pipeline(svd, Normalizer(copy=False)) # clf = Pipeline([("trf", trf), ("lsa", lsa), ("clf", classifiers[options.clf_key])]) - else: + elif options.terms or options.concepts: clf = Pipeline([("trf", trf), ("clf", classifiers[options.clf_key])]) - + else: + clf = Pipeline([("clf", classifiers[options.clf_key])]) return clf def _generate_parsers(): @@ -367,6 +742,8 @@ def _generate_parsers(): "Run on one fold [False]") execution_options.add_argument('-X', action="store_true", dest="cross_validation", default=False, help= "Perform cross validation [False]") + execution_options.add_argument('--fixed_folds', action="store_true", dest="fixed_folds", default=False, help= + "Perform cross validation with fixed folds.") execution_options.add_argument('-i', '--interactive', action="store_true", dest="interactive", default=False, help= \ "Use whole supplied data as training set and classify new inputs from STDIN") @@ -375,10 +752,22 @@ def _generate_parsers(): detailed_options = parser.add_argument_group("Detailed Execution Options") detailed_options.add_argument('--test-size', type=float, dest='test_size', default=0.1, help= "Desired relative size for the test set [0.1]") + detailed_options.add_argument('--optimization', type=str, dest='optimization', default=None, help= + "Whether to use Random Search or Bayesian Optimization for hyperparameter search. [None]", choices = ["grid", "random", "bayesian", None]) + detailed_options.add_argument('--optimization_spaces', type=str, dest='optimization_spaces', default="default_searchspace", help= + "Path to a file that specifies the search spaces for hyperparameters [default_searchspace]") + detailed_options.add_argument('--optimization_iterations', type=int, dest='optimization_iterations', default=10, help= + "Number of iterations in hyperparameter search. [10]") + detailed_options.add_argument('--val-size', type=float, dest='validation_size', default=0., help= + "Desired relative size of the training set used as validation set [0.]") detailed_options.add_argument('--folds', type=int, dest='folds', default=10, help= "Number of folds used for cross validation [10]") detailed_options.add_argument('--toy', type=float, dest='toy_size', default=1.0, help= "Eventually use a smaller block of the data set from the very beginning. [1.0]") + detailed_options.add_argument('--extra_samples_factor', type=float, dest='extra_samples_factor', default=1.0, help= + "This option only has an effect when the '--fixed_folds' option is true. The value determines the factor 'x >= 1' by which\ + the training set is enriched with samples from the 11th fold. 
Hence, the total number of training data will be \ + x * size of tranining set. By default, the value is x = 1.") detailed_options.add_argument('--training-error', action="store_true", dest="training_error", default=False, help=\ "Compute training error") @@ -397,6 +786,12 @@ def _generate_parsers(): "use concepts [False]") feature_options.add_argument('-t', '--terms', action="store_true", dest="terms", default=False, help= \ "use terms [True]") + feature_options.add_argument('--charngrams', action="store_true", dest="charngrams", default=False, help= \ + "use character n-grams [True]") + feature_options.add_argument('--onehot', action="store_true", dest="onehot", default=False, help= \ + "Encode the input words as one hot. [True]") + feature_options.add_argument('--max_features', type=int, dest="max_features", default=None, help= \ + "Specify the maximal number of features to be considered for a BoW model [None, i.e., infinity]") feature_options.add_argument('-s', '--synsets', action="store_true", dest="synsets", default=False, help= \ "use synsets [False]") feature_options.add_argument('-g', '--graphscoring', dest="graph_scoring_method", type=str, default="", \ @@ -414,13 +809,17 @@ def _generate_parsers(): "Do not use IDF") feature_options.add_argument('--no-norm', action="store_false", dest="norm", default=True, help="Do not normalize values") + feature_options.add_argument('--ngram_limit', type=int, dest="ngram_limit", default=1, help= \ + "Specify the n for n-grams to take into account for token-based BoW vectorization. [1]") + feature_options.add_argument('--char_ngram_limit', type=int, dest="char_ngram_limit", default=3, help= \ + "Specify the n for character n-grams to take into account for character n-gram based BoW vectorization. [3]") # group for classifiers classifier_options = parser.add_argument_group("Classifier Options") classifier_options.add_argument('-f', '--classifier', dest="clf_key", default="nn", help= "Specify the final classifier.", choices=["nn", "brknna", "brknnb", "bbayes", "mbayes", "lsvc", "sgd", "sgddt", "rocchio", "rocchiodt", "logregress", "logregressdt", - "mlp", "listnet", "l2rdt", 'mlpthr', 'mlpdt', 'nam']) + "mlp", "listnet", "l2rdt", 'mlpthr', 'mlpdt', 'nam', 'mlpbase', "mlpsoph", "cnn", "lstm"]) classifier_options.add_argument('-a', '--alpha', dest="alpha", type=float, default=1e-7, help= \ "Specify alpha parameter for stochastic gradient descent") classifier_options.add_argument('-n', dest="k", type=int, default=1, help= @@ -434,6 +833,13 @@ def _generate_parsers(): "Performs Grid search to find optimal K") classifier_options.add_argument('-e', type=int, dest="max_iterations", default=5, help= "Determine the number of epochs for the training of several classifiers [5]") + classifier_options.add_argument('--patience', type=int, dest="patience", default=5, help= + "Specify the number of steps of no improvement in validation score before training is stopped. [5]") + classifier_options.add_argument('--num_steps_before_validation', type=int, dest="num_steps_before_validation", default=None, help= + "Specify the number of steps before evaluating on the validation set. [None]") + classifier_options.add_argument('--learning_rate', type=float, dest="learning_rate", default=None, help= + "Determine the learning rate for training of several classifiers. If set to 'None', the learning rate is automatically based on an empirical good value and \ + adapted to the batch size. 
[None]") classifier_options.add_argument('-P', type=str, dest="penalty", default=None, choices=['l1','l2','elasticnet'], help=\ "Penalty term for SGD and other regularized linear models") classifier_options.add_argument('--l2r-alg', type=str, dest="l2r", default="listnet", choices=['listnet','adarank','ca', 'lambdamart'], help=\ @@ -446,6 +852,71 @@ def _generate_parsers(): "Whether the ClassifierStack should make use of all label information and thus take into account possible interdependencies.") classifier_options.add_argument('--l2r-neighbors', dest="l2r_neighbors", type=int, default=45, help= "Specify n_neighbors argument for KneighborsL2RClassifier.") + classifier_options.add_argument('--batch_size', dest="batch_size", type=int, default=256, help= + "Specify batch size for neural network training.") + + # neural network specific options + neural_network_options = parser.add_argument_group("Neural Network Options") + neural_network_options.add_argument('--dropout', type=float, dest="dropout", default=0.5, help= + "Determine the keep probability for all dropout layers.") + neural_network_options.add_argument('--memory', type=float, dest="memory", default=1.0, help= + "Fraction of available GPU-memory to use for experiment.") + neural_network_options.add_argument('--embedding_size', type=int, dest="embedding_size", default=300, help= + "Determine the size of a word embedding vector (for MLP-Soph, CNN, and LSTM if embedding is learned jointly). \ + Specify --embedding_size=0 to skip the embedding layer, if applicable. [300]") + neural_network_options.add_argument('--pretrained_embeddings', type=str, dest="pretrained_embeddings", default=None, help= + "Specify the path to a file contraining pretrained word embeddings. The file must have a format where each line consists of the word\ + followed by the entries of its vectors, separated by blanks. If None is specified, the word embeddings are zero-initialized and trained\ + jointly with the classification task. [None]") + neural_network_options.add_argument('--hidden_activation_function', type=str, dest="hidden_activation_function", default="relu", help= + "Specify the activation function used on the hidden layers in MLP-Base and MLP-Soph. [relu]", choices = ["relu", "tanh", "identity", "swish"]) + neural_network_options.add_argument('--trainable_embeddings', action="store_true", dest="trainable_embeddings", default=False, help= + "Whether to keep training the pretrained embeddings further with classification the task or not. [False]") + neural_network_options.add_argument('--hidden_layers', type=int, dest="hidden_layers", nargs='+', default=[1000], help= + "Specify the number of layers and the respective number of units as a list. The i-th element of the list \ + specifies the number of units in layer i. [1000]") + neural_network_options.add_argument('--bottleneck_layers', type=int, dest="bottleneck_layers", nargs='+', default=None, help= + "Specify the number of bottleneck layers and the respective number of units as a list. The i-th element of the list \ + specifies the number of units in layer i. In contrast to the --hidden_layers option, where the respective model decides\ + how to interprete multiple hidden layers, the bottleneck layers are feed forward layers which are pluged in between \ + the last layer of a particular model (e.g. CNN, LSTM) and the output layer. 
(None)") + neural_network_options.add_argument('--standard_normal', action="store_true", dest="standard_normal", default=False, help= + "Whether to normalize the input features to mean = 0 and std = 1 for MLPSoph. [False]") + neural_network_options.add_argument('--batch_norm', action="store_true", dest="batch_norm", default=False, help= + "Whether to apply batch normalization after at a hidden layer in MLP. [False]") + neural_network_options.add_argument('--snn', action="store_true", dest="snn", default=False, help= + "Whether to use SELU activation and -dropout. If set to False, the activation specified in --hidden_activation_function is used. [False]") + neural_network_options.add_argument('--variational_recurrent_dropout', action="store_true", dest="variational_recurrent_dropout", default=False, help= + "Whether to perform dropout on the recurrent unit between states in addition to dropout on the aggregated output. [False]") + neural_network_options.add_argument('--bidirectional', action="store_true", dest="bidirectional", default=False, help= + "When activated, we create two instances of (potentially multi-layered) LSTMs, where one reads the input from left to right and \ + the other reads it from right to left. [False]") + neural_network_options.add_argument('--iterate_until_maxlength', action="store_true", dest="iterate_until_maxlength", default=False, help= + "When activated, the LSTM always iterates max_features steps, even if the actual sequence is shorter. Instead, it consumes\ + at each additional step the padding symbol. The outputs of steps beyond the actual sequence length are taken into account as well for output aggregation. [False]") + neural_network_options.add_argument('--aggregate_output', type=str, dest='aggregate_output', default="average", help= + "How to aggregate the outputs of an LSTM. 'last' uses the output at the last time step. 'average' takes the mean over all outputs. [average]", + choices = ["average", "last", "attention", "oe-attention", "sum"]) + neural_network_options.add_argument('--pad_special_symbol', type=int, dest="pad_special_symbol", default=0, help= + "How many special tokens to pad after each sample for OE-LSTMs. [0]") + neural_network_options.add_argument('--optimize_threshold', action="store_true", dest="optimize_threshold", default=False, help= + "Optimize the prediction threshold on validation set during training. [False]") + neural_network_options.add_argument('--dynamic_max_pooling_p', type=int, dest="dynamic_max_pooling_p", default=1, help= + "Specify the number of chunks (p) to perform max-pooling over. [1]") + neural_network_options.add_argument('--num_filters', type=int, dest="num_filters", default=100, help= + "Specify the number of filters used in a CNN (per window size). [100]") + neural_network_options.add_argument('--window_sizes', type=int, dest="window_sizes", nargs='+', default=[3,4,5], help= + "Specify the window sizes used for extracting features in a CNN. [[3,4,5]]") + neural_network_options.add_argument('--meta_labeler_min_labels', type=int, dest="meta_labeler_min_labels", default=1, help= + "Specify the minimum number of labels to assign the meta labeler can predict. [1]") + neural_network_options.add_argument('--meta_labeler_max_labels', type=int, dest="meta_labeler_max_labels", default=None, help= + "Specify the maximum number of labels to assign the meta labeler can predict. When 'None' is specified, the maximum \ + is computed from the data. 
[None]") + detailed_options.add_argument('--meta_labeler_phi', type=str, dest='meta_labeler_phi', default=None, help= + "Specify whether to predict number of labels from 'score' or from 'content', or whether to use meta labeler at all (None). [None]", + choices = ["content", "score"]) + neural_network_options.add_argument('--meta_labeler_alpha', type=float, dest="meta_labeler_alpha", default=0.1, help= + "The alpha-weight of predicting the correct number of labels when doing meta-labeling. [0.1]") # persistence_options persistence_options = parser.add_argument_group("Feature Persistence Options") @@ -455,6 +926,8 @@ def _generate_parsers(): "Persisted features will be recalculated and overwritten.") persistence_options.add_argument('--persist_to', dest="persist_to", default=os.curdir + os.sep + 'persistence', help= "Path to persist files.") + persistence_options.add_argument("--tf-model-path", dest="tf_model_path", default=".tmp_best_models", help= + "Directory to store best models for early stopping.") return meta_parser, parser diff --git a/Code/lucid_ml/tftests/__init__.py b/Code/lucid_ml/tftests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Code/lucid_ml/tftests/testutils.py b/Code/lucid_ml/tftests/testutils.py new file mode 100644 index 0000000..a0e3e21 --- /dev/null +++ b/Code/lucid_ml/tftests/testutils.py @@ -0,0 +1,90 @@ +from utils.tf_utils import * +import tensorflow as tf + +class SequenceTests(tf.test.TestCase): + + def testSequenceLength(self): + with self.test_session(): + sequences = [[3, 4, 1, 3, 6, 0, 0, 0], + [3, 4, 1, 3, 6, 4, 0, 0], + [3, 4, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [3, 0, 0, 0, 0, 0, 0, 0], + [3, 4, 1, 3, 6, 1, 1, 1]] + lengths = sequence_length(sequences) + self.assertAllEqual(lengths.eval(), [5,6,3,0,1,8]) + + +class AggregationTests(tf.test.TestCase): + + def testAverage(self): + with self.test_session(): + # output dimensions = [3, 4, 2] + sequences = [[3, 4, 1, 3], + [0, 0, 0, 0], + [3, 0, 0, 0], + [1, 2, 0, 0]] + outputs = [[[1., 1.], [1., 1.], [2., 2.5], [3., 4.]], + [[1., 2.], [3., 4.], [0., 0.], [0., 0.]], + [[5., 6.], [10., 10.,], [10., 10.], [10., 10.]], + [[10., 12.], [0., 0.,],[0., 0.,],[0., 0.,]]] + outputs = tf.constant(outputs, dtype = tf.float32) + expected_result = [[1.75, 2.125], + [0., 0.], + [5., 6.], + [5., 6.]] + lengths = sequence_length(sequences) + result = average_outputs(outputs, lengths) + self.assertAllEqual(result.eval(), expected_result) + + +class DynamicMaxPoolingTests(tf.test.TestCase): + + def testPoolingMask(self): + pass + + def testDynamicMaxPooling(self): + with self.test_session(): + + # add dimension to detector stage + window_size = 1 + num_filters = 3 + max_length = 3 + map_size = max_length - window_size + 1 + #detector : tensor of shape [batch_size, map_size, 1, num_filters] + detector = [[[1., 1., 3.], [2., 2.5, 1.5], [3., 4., 2.]], + [[1., 2., 0.5], [0., 0., 1.], [0., 0., 1.]], + [[5., 6., 20.], [10., 10., 9.5], [10., 10., 11.]], + [[10., 12., 0.], [4., 3., 1.],[2., 2., 7.]]] + detector = tf.expand_dims(detector, axis = 2) + + seq_length = tf.constant([0, 1, 2, 3], dtype = tf.float32) + + # p = 1 + result = dynamic_max_pooling(detector, seq_length, max_length, num_filters, window_size, dynamic_max_pooling_p = 1) + expected_result = [[0., 0., 0.], + [1., 2., 0.5], + [10., 10., 20.], + [10., 12., 7.]] + self.assertAllEqual(result.eval(), expected_result) + + # full sequence length, p = 2, so we check if we can handle the case where text-length divided by number of chunks is not 
an integer + seq_length = tf.constant([3, 3, 3, 3], dtype = tf.float32) + result = dynamic_max_pooling(detector, seq_length, max_length, num_filters, window_size, dynamic_max_pooling_p = 2) + expected_result = [[2., 2.5, 3., 3., 4., 2.], + [1., 2., 1., 0., 0., 1.], + [10., 10., 20., 10., 10., 11.], + [10., 12., 1., 2., 2., 7. ]] + self.assertAllEqual(result.eval(), expected_result) + + # full sequence length, p = 3, which should collapse the sequence dimension of the detector outputs in this case + seq_length = tf.constant([3, 3, 3, 3], dtype = tf.float32) + result = dynamic_max_pooling(detector, seq_length, max_length, num_filters, window_size, dynamic_max_pooling_p = 3) + expected_result = [[1., 1., 3., 2., 2.5, 1.5, 3., 4., 2.], + [1., 2., 0.5, 0., 0., 1., 0., 0., 1.], + [5., 6., 20., 10., 10., 9.5, 10., 10., 11.], + [10., 12., 0., 4., 3., 1., 2., 2., 7.]] + self.assertAllEqual(result.eval(), expected_result) + +if __name__ == '__main__': + tf.test.main() \ No newline at end of file diff --git a/Code/lucid_ml/utils/Extractor.py b/Code/lucid_ml/utils/Extractor.py index 43dcece..efe16b2 100644 --- a/Code/lucid_ml/utils/Extractor.py +++ b/Code/lucid_ml/utils/Extractor.py @@ -3,22 +3,80 @@ import json import os import random +import pandas as pd +import math from utils.thesaurus_reader import ThesaurusReader -# paths to required resources: -# [0: titles/documents, 1: gold, 2: thesaurus] -def load_dataset(DATA_PATHS, key='econ62k', fulltext=False): +def split_labels(labels_string, delimiter = ","): + return labels_string.split(delimiter) + +def load_data(df, fulltext, fixed_folds, label_delimiter = ","): + if fulltext: - data = load_documents(DATA_PATHS[key]['X']) + content_column = "fulltext_path" else: - data = load_titles(DATA_PATHS[key]['X']) - gold = load_gold(DATA_PATHS[key]['y']) - data_list, gold_list = reduce_dicts(data, gold) - tr = ThesaurusReader(DATA_PATHS[key]['thes']) + content_column = "title" + + # bring title/fulltext in dictionary format + content = dict() + for row in df.iterrows(): + content[row[1].loc["id"]] = row[1][content_column] + + # bring goldstandard in dictionary format + gold = dict() + for row in df.iterrows(): + gold[row[1]["id"]] = split_labels(row[1]["labels"], delimiter = label_delimiter) + + # by default, there is only one fold + folds = dict() + folds.update({key : 0 for key in gold}) + if fixed_folds: + for row in df.iterrows(): + folds[row[1]["id"]] = int(row[1]["fold"]) + + data_list, gold_list, fold_list = reduce_dicts([content, gold, folds]) + + if fulltext: + fulltext_indices = [index for index, x in enumerate(data_list) if type(x) == str or not math.isnan(x)] + + def elems_by_indices(some_list): + return [some_list[i] for i in fulltext_indices] + + data_list, gold_list, fold_list = elems_by_indices(data_list), elems_by_indices(gold_list), elems_by_indices(fold_list) + print(len(data_list)) - return data_list, gold_list, tr + return data_list, gold_list, fold_list +# paths to required resources: +# [0: titles/documents, 1: gold, 2: thesaurus] +def load_dataset(DATA_PATHS, key='econ62k', fulltext=False, fixed_folds = False): + data_set = DATA_PATHS[key] + + dataset_format = data_set["format"] if "format" in data_set else "separate" + + if dataset_format == "separate": + + if fulltext: + data = load_documents(DATA_PATHS[key]['X']) + else: + data = load_titles(DATA_PATHS[key]['X']) + gold = load_gold(DATA_PATHS[key]['y']) + data_list, gold_list = reduce_dicts([data, gold]) + tr = ThesaurusReader(DATA_PATHS[key]['thes']) + + return data_list, 
gold_list, tr, None + + elif dataset_format == "combined": + # extract the available folds and keep each folds samples in a separate list + df = pd.read_csv(data_set["X"]) + data_list, gold_list, fold_list = load_data(df, fulltext, fixed_folds, label_delimiter = "," if "label_delimiter" not in data_set else data_set["label_delimiter"]) + tr = None if "thes" not in data_set else ThesaurusReader(data_set['thes']) + return data_list, gold_list, tr, fold_list + + else: + print("Format not recognized:", dataset_format) + raise ValueError("No such format: " + dataset_format) # expected format: \t def load_titles(titles_path): @@ -57,21 +115,21 @@ def load_gold(path): return gold -def reduce_dicts(titles, gold, shuffle=False): - """ reduce 2 dictionaries to 2 lists, +def reduce_dicts(dicts, shuffle=False): + """ reduces a list of dictionaries to a list of lists, providing 'same index' iff 'same key in the dictionary' """ - titles, gold = dict(titles), dict(gold) - titles_list = [] - gold_list = [] - for key in titles: - titles_list.append(titles[key]) - gold_list.append(gold[key]) + dicts = list(map(dict, dicts)) + lists = [[] for i in range(len(dicts))] + + for key in dicts[0]: + for i, some_dict in enumerate(dicts): + lists[i].append(some_dict[key]) if shuffle: # this is the only way to shuffle them - zipped = list(zip(titles_list, gold_list)) + zipped = list(zip(lists)) random.shuffle(zipped) - titles_list, gold_list = zip(*zipped) + lists = zip(*zipped) - return titles_list, gold_list + return tuple(lists) diff --git a/Code/lucid_ml/utils/nltk_normalization.py b/Code/lucid_ml/utils/nltk_normalization.py index 58c84fd..9c5cd82 100644 --- a/Code/lucid_ml/utils/nltk_normalization.py +++ b/Code/lucid_ml/utils/nltk_normalization.py @@ -6,6 +6,7 @@ _alphabet = set(string.ascii_lowercase + string.digits + ' ') word_regexp = r"(?u)\b[a-zA-Z_][a-zA-Z_]+\b" +character_regexp = r"(?u)\b[a-zA-Z_0-9]\b" class NltkNormalizer: def __init__(self): diff --git a/Code/lucid_ml/utils/text_encoding.py b/Code/lucid_ml/utils/text_encoding.py new file mode 100644 index 0000000..b8a6048 --- /dev/null +++ b/Code/lucid_ml/utils/text_encoding.py @@ -0,0 +1,175 @@ +from sklearn.base import BaseEstimator, TransformerMixin + +from utils.nltk_normalization import NltkNormalizer + +import numpy as np + + +class TextEncoder(BaseEstimator, TransformerMixin): + """ + Sk-learn transformer that turns raw text into a fixed-length sequence of one-hot vectors. + Length of the sequence is either set by the user or the length of the longest sample that is + passed to the fit() function. If a sequence is shorter than the maximum length, it is padded with zeros. + If it is longer, the text is truncated to the maximum length. + + If 'pretrained' is specified, it is interpreted as a word embedding file (word2vec format). Any token that does not have an entry in the + embedding table is discarded. + + Parameters + ---------- + tokenize: function, default = NltkNormalizer().split_and_normalize + Function that turns a raw string into a sequence of tokens. + input_format: str, default = content + Determines whether the samples passed to fit() or transform() are processed directly ("content") or if they are interpreted as a path ("filename"). + max_words: int, default = None + Determines the maximum sequence length. If None, the maximum sequence length is determined from the samples passed to fit(). + pretrained: str, default = None + Path to pretrained word embeddings. 
+ restrict_pretrained: bool, default = True + If true, a word embedding table is generated that only contains those words which appear among the samples. + pad_special_symbol: int, default = 0 + Number of times to repeat a special symbol. + """ + def __init__(self, + tokenize = NltkNormalizer().split_and_normalize, + input_format = "content", + max_words = None, + pretrained = None, + restrict_pretrained = True, + pad_special_symbol = 0): + + self.tokenize = tokenize + self.input = input_format + self.max_words = max_words + self.pretrained = pretrained + self.restrict_pretrained = restrict_pretrained + self.pad_special_symbol = pad_special_symbol + + def _maybe_load_text(self, text): + if self.input == "filename": + with open(text, 'r') as text_file: + text = text_file.read() + + return text + + def _limit_num_words(self, words, max_length): + if self.max_words is not None: + return words[:max_length] + else: + return words + + @staticmethod + def _load_pretrained_vocabulary(filename, word_restrictions): + mapping = {} + + with open(filename + ".tmp", 'w') as temp_embedding_file: + + with open(filename,'r') as embedding_file: + embedding_size = int(embedding_file.readline().strip().split(" ")[1]) + + i = 0 + for line in embedding_file.readlines(): + row = line.strip().split(' ') + + # make sure we dont use escape sequences and so on + if len(row) == embedding_size + 1: + if row[0] in mapping: + print(row[0], "is already in mapping") + elif word_restrictions is None or row[0] in word_restrictions: + mapping[row[0]] = i + i += 1 + temp_embedding_file.write(line) + return mapping, i + + def _extract_words(self, text): + # if full-text: load it first + text = self._maybe_load_text(text) + + # tokenize training text + words = self.tokenize(text) + words = self._limit_num_words(words, self.max_words) + + return words + + def _set_of_all_words(self, X): + all_words = set() + for text in X: + all_words.update(set(self._extract_words(text))) + + return all_words + + def fit(self, X, y = None): + + if self.pretrained is None: + mapping = {} + max_index = 1 + max_length = 0 + for text in X: + + words = self._extract_words(text) + + # build mapping from word to index + + for word in words: + if word not in mapping: + mapping[word] = max_index + max_index += 1 + + # determine maximum length of a text for padding + if len(words) > max_length: + max_length = len(words) + + + else: + + if self.restrict_pretrained: + all_words = self._set_of_all_words(X) + else: + all_words = None + + mapping, max_index = TextEncoder._load_pretrained_vocabulary(self.pretrained, all_words) + max_length = 0 + for text in X: + + words = self._extract_words(text) + + if len(words) > max_length: + max_length = len(words) + + # need to account for the special symbol + if self.pad_special_symbol > 0: + max_index += 1 + max_length = max_length + self.pad_special_symbol + + # save variables required for transformation step + self.mapping = mapping + self.max_index = max_index - 1 + self.max_length = max_length + + return self + + def transform(self, X, y = None): + + encoding_matrix = np.zeros((len(X), self.max_length), dtype = np.int32) + for i, text in enumerate(X): + + text = self._maybe_load_text(text) + + # tokenize test text + words = self.tokenize(text) + # make sure we do not exceed the maximum length from training samples + words = self._limit_num_words(words, self.max_length - self.pad_special_symbol) + + # apply mapping from word to integer + id_sequence = np.array([self.mapping[word] for word in words if word in 
self.mapping]) + + # add special padding token if necessary + if self.pad_special_symbol > 0: + padding_sequence = np.array([self.max_index for _ in range(self.pad_special_symbol)]) + id_sequence = np.concatenate((id_sequence, padding_sequence)) + + encoding_matrix[i, :len(id_sequence)] = id_sequence + + max_index_column = np.zeros((len(X), 1), dtype = np.int32) + max_index_column.fill(self.max_index) + return np.hstack((encoding_matrix, max_index_column)) diff --git a/Code/lucid_ml/utils/tf_utils.py b/Code/lucid_ml/utils/tf_utils.py new file mode 100644 index 0000000..d07197a --- /dev/null +++ b/Code/lucid_ml/utils/tf_utils.py @@ -0,0 +1,151 @@ +import tensorflow as tf +import numpy as np +from scipy.sparse import csr_matrix + +def tf_normalize(X, input_tensor): + """ + Given data as a sparse matrix X, this function scales an input tensor such that each + column in the output tensor has mean 0 and variance 1. + + >>> X = np.array([[0, 0, 0, 0, 1, 1, 1, 2, 3], + ... [1, 1, 0, 2, 3, 4, 5, 1, 0], + ... [0, 0 ,4, 5, -1, 4, 2, 1, 0]]) + >>> means = np.mean(X, axis = 0) + >>> stds = np.std(X, axis = 0) + >>> scaled_X = (X - means) / stds + >>> X = csr_matrix(X) + >>> input_tensor = tf.placeholder(dtype = tf.float32, shape = [3, 9]) + >>> output_tensor = tf_normalize(X, input_tensor) + >>> sess = tf.Session() + >>> output = sess.run(output_tensor, feed_dict = {input_tensor : X.toarray()}) + >>> output + array([[-0.70710683, -0.70710683, -0.70710683, -1.13555002, 0. , + -1.41421354, -0.98058075, 1.41421342, 1.41421354], + [ 1.41421342, 1.41421342, -0.70710683, -0.16222139, 1.2247448 , + 0.70710677, 1.37281287, -0.70710689, -0.70710677], + [-0.70710683, -0.70710683, 1.41421342, 1.29777145, -1.2247448 , + 0.70710677, -0.39223233, -0.70710689, -0.70710677]], dtype=float32) + """ + m = X.mean(axis = 0) + + # because we are dealing with sparse matrices, we need to compute the variance as + # E[X^2] - E[X]^2 + X_square = X.power(2) + m_square = X_square.mean(axis = 0) + v = m_square - np.power(m, 2) + s = np.sqrt(v) + + #make sure not to divide by zero when scaling + s[s == 0] = 1 + + m = tf.constant(m, dtype = tf.float32) + s = tf.constant(s, dtype = tf.float32) + + scaled_input = (input_tensor - m) / s + + return scaled_input + +def sequence_length(sequence): + """ + Takes as input a tensor of dimensions [batch_size, max_len] which encodes some sequence of maximum length max_len as + a sequence of positive ids. For positions beyond the length of the actual sequence, the id is assumed to be zero (i.e., zero is used for padding). + + This function returns a one-dimensional tensor of size [batch_size], where each entry denotes the length of the corresponding sequence. + """ + used = tf.sign(sequence) + length = tf.reduce_sum(used, 1) + length = tf.cast(length, tf.int32) + return length + +def dynamic_max_pooling(detector, seq_length, max_length, num_filters, window_size, dynamic_max_pooling_p = 1): + """ + Performs dynamic max pooling as in [XML-CNN]. That is, it splits the outputs of the detector stage into 'p' chunks, performs + max-pooling on each chunk and concatenates the outputs. The function assumes the CNN from which the detector stage results to use 'VALID' + padding as well as strides = [1, 1, 1, 1]. + + If the length of the sequence can not be divided into evenly sized chunks, we make the last chunk contain the remainder of the text. 
+ + References + ---------- + [XML-CNN] J Liu, WC Chang, Y Wu, Y Yang + "Deep Learning for Extreme Multi-label Text Classification" + Proceedings of the 40th International ACM SIGIR Conference on Research and Development in Information Retrieval + + Parameters + ---------- + detector : tensor of shape [batch_size, map_size, 1, num_filters] + The output from the 'detector' stage of a CNN with a configuration (number of filters, window_size, ...) such that it results in + a tensor with the according feature map size and number of filters. + seq_length : tensor of shape [batch_size] containing the length of the text. + max_length : maximum length of the sequence + num_filters : number of filters used in the CNN + window_size : window size used in the CNN + dynamic_max_pooling_p : the number of chunks 'p' to split the outputs of the detector stage into. + + Returns + ------- + Tensor of shape [batch_size, p * num_filters] + The concatenated outputs of the max-pooling operations over the 'p' chunks. + """ + # assumptions of CNN from which the detector outputs result + stride = [1,1,1,1] + detector_output_length = seq_length - window_size + 1 + + # dynamic max-pooling: extract maximum for each chunk + chunks_size = tf.ceil(tf.divide(detector_output_length, dynamic_max_pooling_p)) + chunk_poolings = [] + for i in range(dynamic_max_pooling_p): + + # make sure we don't get out of bounds at end of sequence + cur_chunk_size = chunks_size if i != dynamic_max_pooling_p - 1 or dynamic_max_pooling_p == 1 else detector_output_length - i * chunks_size + + # create a mask for the entire sequence where only those are selected which are in the current chunk + start_indices = i * chunks_size + end_indices = i * chunks_size + cur_chunk_size + + neg_mask_start = 1 - tf.cast(tf.sequence_mask(start_indices, maxlen = max_length - window_size + 1), tf.float32) + mask_end = tf.cast(tf.sequence_mask(end_indices, maxlen = max_length - window_size + 1), tf.float32) + final_mask = tf.multiply(neg_mask_start, mask_end) + final_mask = tf.expand_dims(final_mask, axis = 2) + final_mask = tf.expand_dims(final_mask, axis = 3) + + extracted_chunk = tf.multiply(final_mask, detector) + pooling = tf.nn.max_pool(extracted_chunk, + ksize = [1, max_length - window_size + 1, 1, 1], + strides = stride, + padding = "VALID") + pooling = tf.reshape(pooling, [-1, num_filters]) + chunk_poolings.append(pooling) + + concatenated_pooled_chunks = tf.concat(chunk_poolings, 1) + return concatenated_pooled_chunks + +def average_outputs(outputs, seq_length): + """ + Given the padded outputs of an RNN and the actual length of the sequence, this function computes the average + over all (non-padded) outputs. In the special case where the length is 0, the function returns 0. + + Parameters + ---------- + outputs : tensor of shape [batch_size, max_length, output_dimensions] + The output from an RNN with hidden representation size 'output_dimensions'. + seq_length : tensor of shape [batch_size] containing the number of valid outputs in 'outputs'. + + Returns + ------- + Tensor of shape [batch_size, output_dimensions] + The average over all outputs in the sequence. 
+ """ + # average over outputs at all time steps + seq_mask = tf.cast(tf.sequence_mask(seq_length, maxlen = outputs.get_shape().as_list()[1]), tf.float32) + seq_mask = tf.expand_dims(seq_mask, axis = 2) + outputs = outputs * seq_mask + output_state = tf.reduce_sum(outputs, axis = 1) + seq_length_reshaped = tf.cast(tf.reshape(seq_length, [-1, 1]), tf.float32) + minimum_length = tf.ones_like(seq_length_reshaped, dtype=tf.float32) + output_state = tf.div(output_state, tf.maximum(seq_length_reshaped, minimum_length)) + return output_state + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/Code/requirements-stable.txt b/Code/requirements-stable.txt index 75e4731..c562878 100644 --- a/Code/requirements-stable.txt +++ b/Code/requirements-stable.txt @@ -24,6 +24,6 @@ singledispatch==3.4.0.3 six==1.12.0 sklearn==0.0 tensorboard==1.12.2 -tensorflow==1.12.0 +tensorflow==1.4.0 termcolor==1.1.0 Werkzeug==0.14.1 diff --git a/Code/requirements.txt b/Code/requirements.txt index af65f9a..69165c6 100644 --- a/Code/requirements.txt +++ b/Code/requirements.txt @@ -8,5 +8,8 @@ nltk datrie decorator pandas -tensorflow +tensorflow==1.4 keras +h5py +bayesian-optimization +hyperopt diff --git a/Experiments/cnn_final_experiments.cfg b/Experiments/cnn_final_experiments.cfg new file mode 100644 index 0000000..d439067 --- /dev/null +++ b/Experiments/cnn_final_experiments.cfg @@ -0,0 +1,17 @@ +# these experiments are used to compute the final scores +# singlefold: +# python run.py -f cnn --onehot -k file_paths.json --fixed_folds --folds=1 -v --batch_size=256 -e 200 --val-size=0.2 -C ../../Experiments/cnn_final_experiments.cfg -o cnn_final_experiments_singlefold.csv --optimize_threshold --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=250 --num_steps_before_validation=2000 --embedding_size=300 --pretrained_embeddings=<path/to/embedding/file/in/w2v/format> --trainable_embeddings --learning_rate=0.001 --dropout=0.75 --window_sizes 2 3 4 5 8 +# 10-fold-crossvalidation: +# python run.py -f cnn --onehot -k file_paths_ks.json --fixed_folds --folds=10 -v --batch_size=256 -e 200 --val-size=0.2 -C ../../Experiments/cnn_final_experiments.cfg -o cnn_final_experiments.csv --optimize_threshold --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=250 --num_steps_before_validation=2000 --embedding_size=300 --pretrained_embeddings=<path/to/embedding/file/in/w2v/format> --trainable_embeddings --learning_rate=0.001 --dropout=0.75 --window_sizes 2 3 4 5 8 +# PubMed +--extra_samples_factor=1 -K pubmed --dynamic_max_pooling_p=1 --bottleneck_layers 500 --num_filters=400 +--extra_samples_factor=2 -K pubmed --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=4 -K pubmed --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=8 -K pubmed --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=100 -K pubmed --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +# EconBiz +--extra_samples_factor=1 -K econbiz --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=2 -K econbiz --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=4 -K econbiz --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=8 -K econbiz --dynamic_max_pooling_p=1 --bottleneck_layer 500 --num_filters=400 +--extra_samples_factor=100 -K econbiz 
--dynamic_max_pooling_p=1 --bottleneck_layers 500 --num_filters=400 diff --git a/Experiments/final_base_mlp_experiments.cfg b/Experiments/final_base_mlp_experiments.cfg new file mode 100644 index 0000000..896c5f2 --- /dev/null +++ b/Experiments/final_base_mlp_experiments.cfg @@ -0,0 +1,17 @@ +# experiments for final evaluation with increasing samples factors for base mlp +# for single fold: +# python run.py -f mlpbase -t -k file_paths.json --fixed_folds --folds=1 -v --batch_size=256 -e 200 --val-size=0.2 --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=25000 --num_steps_before_validation=2000 --learning_rate=0.001 --dropout=0.5 --memory=0.3 -C ../../Experiments/final_base_mlp_experiments.cfg -o final_base_mlp_singlefold.csv --embedding_size=0 --optimize_threshold +# for cross-validation: +# python run.py -f mlpbase -t -k file_paths_ks.json --fixed_folds --folds=10 -v --batch_size=256 -e 200 --val-size=0.2 --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=25000 --num_steps_before_validation=2000 --learning_rate=0.001 --dropout=0.5 --memory=0.3 -C ../../Experiments/final_base_mlp_experiments.cfg -o final_base_mlp_crossvalidated.csv --embedding_size=0 --optimize_threshold +# EconBiz +--extra_samples_factor=1 -K econbiz +--extra_samples_factor=2 -K econbiz +--extra_samples_factor=4 -K econbiz +--extra_samples_factor=8 -K econbiz +--extra_samples_factor=100 -K econbiz +# PubMed +--extra_samples_factor=1 -K pubmed +--extra_samples_factor=2 -K pubmed +--extra_samples_factor=4 -K pubmed +--extra_samples_factor=8 -K pubmed +--extra_samples_factor=100 -K pubmed diff --git a/Experiments/final_mlp_experiments.cfg b/Experiments/final_mlp_experiments.cfg new file mode 100644 index 0000000..ccdde96 --- /dev/null +++ b/Experiments/final_mlp_experiments.cfg @@ -0,0 +1,17 @@ +# experiments for final evaluation with increasing samples factors +# for single fold: +# python run.py -f mlpsoph -t -k file_paths.json --fixed_folds --folds=1 -v --batch_size=256 -e 200 --val-size=0.2 --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=25000 --num_steps_before_validation=2000 --learning_rate=0.001 --dropout=0.5 --memory=0.3 -C ../../Experiments/final_mlp_experiments.cfg -o final_mlp.csv --embedding_size=0 --ngram_limit=2 --optimize_threshold +# for cross-validation: +# python run.py -f mlpsoph -t -k file_paths_ks.json --fixed_folds --folds=10 -v --batch_size=256 -e 200 --val-size=0.2 --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=25000 --num_steps_before_validation=2000 --learning_rate=0.001 --dropout=0.5 --memory=0.3 -C ../../Experiments/final_mlp_experiments.cfg -o final_mlp.csv --embedding_size=0 --ngram_limit=2 --optimize_threshold +# EconBiz +--extra_samples_factor=1 -K econbiz --hidden_layers 2000 +--extra_samples_factor=2 -K econbiz --hidden_layers 2000 +--extra_samples_factor=4 -K econbiz --hidden_layers 2000 +--extra_samples_factor=8 -K econbiz --hidden_layers 2000 +--extra_samples_factor=100 -K econbiz --hidden_layers 2000 +# PubMed +--extra_samples_factor=1 -K pubmed --hidden_layers 1000 1000 --batch_norm --dropout=1. +--extra_samples_factor=2 -K pubmed --hidden_layers 1000 1000 --batch_norm --dropout=1. +--extra_samples_factor=4 -K pubmed --hidden_layers 1000 1000 --batch_norm --dropout=1. +--extra_samples_factor=8 -K pubmed --hidden_layers 1000 1000 --batch_norm --dropout=1. +--extra_samples_factor=100 -K pubmed --hidden_layers 1000 1000 --batch_norm --dropout=1.
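The hyperparameter files read via --optimization_spaces (parsed by _make_space above) are plain text with one comma-separated line per option and '#' for comments. In grid mode the fields are the option name, a type (int, float, string, bool, or list-<type>, whose candidate values are semicolon-separated), followed by the candidate values; in random mode the second and third fields are numeric bounds and the fourth names a hyperopt distribution, while in bayesian mode the remaining fields are initial values to explore. The lines below are only a hypothetical grid-mode sketch, not a search space used for the paper:
# hypothetical grid over keep probability, hidden layout, and learning rate
dropout,float,0.5,0.75,1.0
hidden_layers,list-int,1000,1000;1000
learning_rate,float,0.001,0.01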
diff --git a/Experiments/lstm_experiments_final.cfg b/Experiments/lstm_experiments_final.cfg new file mode 100644 index 0000000..2ef7a76 --- /dev/null +++ b/Experiments/lstm_experiments_final.cfg @@ -0,0 +1,17 @@ +# experiments for final run of LSTM +# single fold +# python run.py -f lstm --onehot -k file_paths.json --fixed_folds --folds=1 -v --batch_size=256 -e 200 --val-size=0.2 -C ../../Experiments/lstm_experiments_final.cfg -o lstm_experiments_final_singlefold.csv --optimize_threshold --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=250 --num_steps_before_validation=2000 --embedding_size=300 --pretrained_embeddings=<path/to/embedding/file/in/w2v/format> --trainable_embeddings --bidirectional --aggregate_output=attention +# crossvalidation +# # python run.py -f lstm --onehot -k file_paths_ks.json --fixed_folds --folds=10 -v --batch_size=256 -e 200 --val-size=0.2 -C ../../Experiments/lstm_experiments_final.cfg -o lstm_experiments_final_crossval.csv --optimize_threshold --patience=10 --tf-model-path=<path/to/temporary/folder> --max_features=250 --num_steps_before_validation=2000 --embedding_size=300 --pretrained_embeddings=<path/to/embedding/file/in/w2v/format> --trainable_embeddings --bidirectional --aggregate_output=attention +# PubMed +-K pubmed --extra_samples_factor=1 --learning_rate=0.001 --dropout=0.75 --hidden_layers 1536 +-K pubmed --extra_samples_factor=2 --learning_rate=0.001 --dropout=0.75 --hidden_layers 1536 +-K pubmed --extra_samples_factor=4 --learning_rate=0.001 --dropout=0.75 --hidden_layers 1536 +-K pubmed --extra_samples_factor=8 --learning_rate=0.001 --dropout=0.75 --hidden_layers 1536 +-K pubmed --extra_samples_factor=100 --learning_rate=0.001 --dropout=0.75 --hidden_layers 1536 +# EconBiz +-K econbiz --extra_samples_factor=1 --learning_rate=0.001 --dropout=0.5 --hidden_layers 1536 +-K econbiz --extra_samples_factor=2 --learning_rate=0.001 --dropout=0.5 --hidden_layers 1536 +-K econbiz --extra_samples_factor=4 --learning_rate=0.001 --dropout=0.5 --hidden_layers 1536 +-K econbiz --extra_samples_factor=8 --learning_rate=0.001 --dropout=0.5 --hidden_layers 1536 +-K econbiz --extra_samples_factor=100 --learning_rate=0.001 --dropout=0.5 --hidden_layers 1536 diff --git a/LICENSE b/LICENSE index c6779c1..a68397a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2016, Dennis Brunsch, Lukas Galke, Florian Mai, Alan Schelten, Kiel University +Copyright (c) 2016-2018, Dennis Brunsch, Lukas Galke, Florian Mai, Alan Schelten, Kiel University All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -9,4 +9,4 @@ Redistribution and use in source and binary forms, with or without modification, 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 3df23f1..eaa7661 100644 --- a/README.md +++ b/README.md @@ -74,3 +74,17 @@ python3 run.py -tf sgd -k Code/lucid_ml/file_paths.json -K example-titles --interactive where `file_paths.json` should contain the key given by `-K` specifying the paths to data (`X`), the gold standard (`y`), and the thesaurus (`thes`). + +## Using Deep Learning for Title-Based Semantic Subject Indexing to Reach Competitive Performance to Full-Text + +This repository has merged the code for the JCDL paper [Using Deep Learning for Title-Based Semantic Subject Indexing to Reach Competitive Performance to Full-Text](https://arxiv.org/abs/1801.06717) from Florian Mai's [fork](https://github.com/florianmai/Quadflor). + +## Replicating the results of the JCDL paper + +In order to enhance the reproducibility of our study, we uploaded a copy of the title datasets to Kaggle. Moreover, we provide the configurations used to produce the results from the paper. + +To rerun any of the (title) experiments, do the following: +1. Download the [econbiz.csv and pubmed.csv](https://www.kaggle.com/hsrobo/titlebased-semantic-subject-indexing) files and copy them to the folder *Resources*. +2. Open the .cfg file of the respective method that you want to run (MLP, BaseMLP, CNN, or LSTM) from the *Experiments* folder. Copy the command in the third (if you want to evaluate on a single fold) or fifth (if you want to do a full 10-fold cross-validation) line. +3. In the command, adjust the --tf-model-path parameter (which specifies where to save the weights of the models; these can amount to gigabytes, so make sure you have enough disk space) and the --pretrained_embeddings parameter to the location of the GloVe word vectors file, which you need to download [here](https://nlp.stanford.edu/projects/glove/). +4. *cd* to the folder *Code/lucid_ml* and run the command. diff --git a/paper_long.pdf b/paper_long.pdf new file mode 100644 index 0000000..0a5eda5 Binary files /dev/null and b/paper_long.pdf differ
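The -K keys used in the replication commands must resolve to entries in file_paths.json. Judging from the combined-CSV branch of load_dataset in Code/lucid_ml/utils/Extractor.py, an entry for one of the Kaggle datasets could look roughly like the following; the key name and the relative path are placeholders that assume the CSV has been copied to the *Resources* folder as described above:
{
    "econbiz": {
        "X": "../../Resources/econbiz.csv",
        "format": "combined"
    }
}
The loader expects the CSV to provide an 'id' column, a 'title' column (or 'fulltext_path' for full-text runs), a 'labels' column, and, when --fixed_folds is used, a 'fold' column; an optional 'label_delimiter' entry overrides the default comma, and a 'thes' entry pointing to a thesaurus file is optional for this format.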