diff --git a/pylearn2/models/mlp.py b/pylearn2/models/mlp.py
index e833f96a14..4106f87d4a 100755
--- a/pylearn2/models/mlp.py
+++ b/pylearn2/models/mlp.py
@@ -3812,6 +3812,52 @@ def L1WeightDecay(*args, **kwargs):
     return _L1WD(*args, **kwargs)
 
 
+class QuantileRegression(Linear):
+    """
+    A linear layer for quantile regression.
+
+    Quantile regression (http://en.wikipedia.org/wiki/Quantile_regression)
+    uses an asymmetric ("pinball") cost, so the layer learns an estimator
+    of a chosen quantile of the targets' conditional distribution.
+
+    Parameters
+    ----------
+    layer_name : str
+        The layer name.
+    percentile : float, 0 < percentile < 1
+        The quantile to estimate, e.g. 0.5 for the median.
+    """
+
+    def __init__(self,
+                 layer_name,
+                 percentile=0.2,
+                 **kwargs):
+        Linear.__init__(self, 1, layer_name, **kwargs)
+        self.percentile = percentile
+
+    @wraps(Layer.get_layer_monitoring_channels)
+    def get_layer_monitoring_channels(self,
+                                      state_below=None,
+                                      state=None,
+                                      targets=None):
+        rval = Linear.get_layer_monitoring_channels(
+            self,
+            state_below,
+            state,
+            targets)
+        assert isinstance(rval, OrderedDict)
+        if targets is not None:
+            rval['qcost'] = (T.abs_(targets - state) *
+                             (0.5 + (self.percentile - 0.5) *
+                              T.sgn(targets - state))).mean()
+        return rval
+
+    @wraps(Layer.cost_matrix)
+    def cost_matrix(self, Y, Y_hat):
+        return T.abs_(Y - Y_hat) * (0.5 + (self.percentile - 0.5) *
+                                    T.sgn(Y - Y_hat))
+
+
 class LinearGaussian(Linear):
 
     """
diff --git a/pylearn2/models/tests/test_mlp.py b/pylearn2/models/tests/test_mlp.py
index 88b9169504..7aa19d9269 100644
--- a/pylearn2/models/tests/test_mlp.py
+++ b/pylearn2/models/tests/test_mlp.py
@@ -24,12 +24,13 @@
                                  exhaustive_dropout_average,
                                  sampled_dropout_average, CompositeLayer,
                                  max_pool, mean_pool, pool_dnn,
-                                 SigmoidConvNonlinearity, ConvElemwise)
+                                 SigmoidConvNonlinearity, ConvElemwise,
+                                 QuantileRegression)
 from pylearn2.space import VectorSpace, CompositeSpace, Conv2DSpace
 from pylearn2.utils import is_iterable, sharedX
 from pylearn2.expr.nnet import pseudoinverse_softmax_numpy
 
 
 class IdentityLayer(Linear):
 
     dropout_input_mask_value = -np.inf
@@ -1389,3 +1389,38 @@ def test_pooling_with_anon_variable():
                       image_shape=im_shp, try_dnn=False)
     pool_1 = mean_pool(X_sym, pool_shape=shp, pool_stride=strd,
                        image_shape=im_shp)
+
+
+def test_quantile_regression():
+    """
+    Train a QuantileRegression layer on noisy linear data for several
+    quantiles, and check that it recovers the generating coefficients
+    and the corresponding quantile of the uniform noise.
+    """
+    np.random.seed(2)
+    nb_rows = 1000
+    X = np.random.normal(size=(nb_rows, 2)).astype(theano.config.floatX)
+    noise = np.random.rand(nb_rows, 1)
+    coeffs = np.array([[3.], [4.]])
+    y_0 = np.dot(X, coeffs)
+    y = y_0 + noise
+    dataset = DenseDesignMatrix(X=X, y=y)
+    for percentile in [0.22, 0.5, 0.65]:
+        mlp = MLP(
+            nvis=2,
+            layers=[
+                QuantileRegression('quantile_regression_layer',
+                                   init_bias=0.0,
+                                   percentile=percentile,
+                                   irange=0.1)
+            ]
+        )
+        train = Train(dataset, mlp, SGD(0.05, batch_size=100))
+        train.algorithm.termination_criterion = EpochCounter(100)
+        train.main_loop()
+        inputs = mlp.get_input_space().make_theano_batch()
+        outputs = mlp.fprop(inputs)
+        theano.function([inputs], outputs, allow_input_downcast=True)(X)
+        layer = mlp.layers[0]
+        assert np.allclose(layer.get_weights(), coeffs, rtol=0.05)
+        assert np.allclose(layer.get_biases(), percentile, rtol=0.05)
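
Why the cost works: cost_matrix implements the standard quantile ("pinball")
loss, |Y - Y_hat| * (0.5 + (percentile - 0.5) * sgn(Y - Y_hat)), which equals
percentile * (Y - Y_hat) for under-predictions and (1 - percentile) *
(Y_hat - Y) for over-predictions, and the constant prediction minimizing it
over a sample is that sample's empirical quantile. That is the property the
test's bias assertion relies on: with the weights matching coeffs, the
residual is the Uniform(0, 1) noise, whose percentile-quantile is the
percentile itself. A minimal NumPy-only sketch of this property; the
pinball_loss helper below is illustrative, not part of the patch:

    import numpy as np

    def pinball_loss(y, y_hat, percentile):
        # Same expression cost_matrix builds symbolically:
        # |y - y_hat| * (0.5 + (percentile - 0.5) * sign(y - y_hat)),
        # i.e. percentile * u for u = y - y_hat > 0,
        # and (1 - percentile) * |u| for u < 0.
        residual = y - y_hat
        return np.mean(np.abs(residual) *
                       (0.5 + (percentile - 0.5) * np.sign(residual)))

    rng = np.random.RandomState(2)
    y = rng.rand(100000)  # Uniform(0, 1), like the test's noise term
    percentile = 0.22

    # Scan constant predictions; the loss bottoms out near the 0.22
    # quantile of Uniform(0, 1), which is 0.22 itself.
    candidates = np.linspace(0.0, 1.0, 201)
    losses = [pinball_loss(y, c, percentile) for c in candidates]
    best = candidates[int(np.argmin(losses))]
    assert abs(best - np.percentile(y, 100 * percentile)) < 0.02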