From 61bcbf578d8d81d304dadcfec4067fedb2a6f59c Mon Sep 17 00:00:00 2001 From: TensorFlow Ranking Date: Tue, 22 Aug 2023 12:21:56 -0700 Subject: [PATCH] Add PWA@k metric. PiperOrigin-RevId: 559189393 --- tensorflow_ranking/python/metrics.py | 42 +++++++++ tensorflow_ranking/python/metrics_impl.py | 69 +++++++++++++- .../python/metrics_impl_test.py | 92 +++++++++++++++++++ tensorflow_ranking/python/metrics_test.py | 24 +++++ 4 files changed, 226 insertions(+), 1 deletion(-) diff --git a/tensorflow_ranking/python/metrics.py b/tensorflow_ranking/python/metrics.py index 8386ca2..9971267 100644 --- a/tensorflow_ranking/python/metrics.py +++ b/tensorflow_ranking/python/metrics.py @@ -72,6 +72,9 @@ class RankingMetricKey(object): # Hits. For binary relevance. HITS = 'hits' + # Position Weighted Average. + PWA = 'pwa' + def compute_mean(metric_key, labels, @@ -107,6 +110,7 @@ def compute_mean(metric_key, RankingMetricKey.ORDERED_PAIR_ACCURACY: metrics_impl.OPAMetric(name), RankingMetricKey.BPREF: metrics_impl.BPrefMetric(name, topn), RankingMetricKey.HITS: metrics_impl.HitsMetric(metric_key, topn), + RankingMetricKey.PWA: metrics_impl.PWAMetric(metric_key, topn), } assert metric_key in metric_dict, ('metric_key %s not supported.' % metric_key) @@ -268,6 +272,15 @@ def _hits_fn(labels, predictions, features): topn=topn, name=name) + def _pwa_fn(labels, predictions, features): + """Returns pwa as the metric.""" + return pwa( + labels, + predictions, + weights=_get_weights(features), + topn=topn, + name=name) + metric_fn_dict = { RankingMetricKey.ARP: _average_relevance_position_fn, RankingMetricKey.MRR: _mean_reciprocal_rank_fn, @@ -281,6 +294,7 @@ def _hits_fn(labels, predictions, features): RankingMetricKey.ALPHA_DCG: _alpha_discounted_cumulative_gain_fn, RankingMetricKey.BPREF: _binary_preference_fn, RankingMetricKey.HITS: _hits_fn, + RankingMetricKey.PWA: _pwa_fn, } assert metric_key in metric_fn_dict, ('metric_key %s not supported.' % metric_key) @@ -716,3 +730,31 @@ def hits(labels, # TODO: Add mask argument for metric.compute() call hits_value, per_list_weights = metric.compute(labels, predictions, weights) return tf.compat.v1.metrics.mean(hits_value, per_list_weights) + + +def pwa(labels, + predictions, + weights=None, + topn=None, + name=None): + """Computes PWA. + + Args: + labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a + relevant example. + predictions: A `Tensor` with shape [batch_size, list_size]. Each value is + the ranking score of the corresponding example. + weights: A `Tensor` of the same shape of predictions or [batch_size, 1]. The + former case is per-example and the latter case is per-list. + topn: An integer cutoff specifying how many examples to consider for this + metric. If None, the whole list is considered. + name: A string used as the name for this metric. + + Returns: + A metric for the Position Weighted Average of the batch. + """ + metric = metrics_impl.PWAMetric(name, topn) + with tf.compat.v1.name_scope(metric.name, 'pwa', + (labels, predictions, weights)): + pwa_value, per_list_weights = metric.compute(labels, predictions, weights) + return tf.compat.v1.metrics.mean(pwa_value, per_list_weights) diff --git a/tensorflow_ranking/python/metrics_impl.py b/tensorflow_ranking/python/metrics_impl.py index 07b8b16..a6c8968 100644 --- a/tensorflow_ranking/python/metrics_impl.py +++ b/tensorflow_ranking/python/metrics_impl.py @@ -163,7 +163,7 @@ def _per_list_recall(labels, predictions, topn, mask): mask: A mask indicating which entries are valid for computing the metric. Returns: - A `Tensor` of size [batch_size, 1] containing the precision of each query + A `Tensor` of size [batch_size, 1] containing the recall of each query respectively. """ sorted_labels = utils.sort_by_scores(predictions, [labels], topn=topn, @@ -896,3 +896,70 @@ def _compute_impl(self, labels, predictions, weights, mask): relevance=tf.cast(tf.greater_equal(relevance, 1.0), dtype=tf.float32)) return bpref, per_list_weights + + +class PWAMetric(_RankingMetric): + """Construct a custom Position-Weighted Average Metric. + + For each query we order the results by scores and compute: + + pwa = (ratings[0] * position_weights[0] + ... + + ratings[topn - 1] * position_weights[topn - 1]) / + (position_weights[0] + ... + position_weights[topn - 1]) + + where position_weights = (1. / 1, 1. / 2, ..., 1. / topn) + + Metric value for the whole dataset is weighted sum over pwa values for + individual queries: + + result = pwa(query_0) * weights[0] + pwa(query_1) * weights[1] + ... + + For this metrcs, weights should be a `Tensor` of the shape [batch_size, 1]. + """ + + def __init__(self, name, topn=5, ragged=False): + """Constructor.""" + super().__init__(ragged=ragged) + self._name = name + self._topn = topn + + @property + def name(self): + """The metric name.""" + return self._name + + def compute(self, labels, predictions, weights=None, mask=None): + """See `_RankingMetric`.""" + if weights is not None: + weights_tensor = tf.convert_to_tensor(value=weights) + predictions_tensor = tf.convert_to_tensor(value=predictions) + expected_shape = tf.zeros([tf.shape(predictions_tensor)[0], 1]) + if not weights_tensor.shape.is_compatible_with(expected_shape.shape): + raise ValueError('Weights should be a `Tensor` of the shape' + '[batch_size, 1]') + return super().compute(labels, predictions, weights, mask) + + def _compute_impl(self, labels, predictions, weights, mask): + """See `_RankingMetric`.""" + topn = tf.shape(predictions)[1] if self._topn is None else self._topn + sorted_labels, sorted_mask = utils.sort_by_scores( + predictions, [labels, mask], topn=topn, mask=mask) + + sorted_list_size = tf.shape(input=sorted_labels)[1] + position_weights = 1.0 / tf.cast( + tf.range(1, sorted_list_size + 1), dtype=tf.float32) + masked_position_weights = (tf.cast(sorted_mask, dtype=tf.float32) + * position_weights) + pwa = tf.compat.v1.math.divide_no_nan( + tf.reduce_sum(input_tensor=tf.multiply(sorted_labels, + masked_position_weights), + axis=1, keepdims=True), + tf.reduce_sum(input_tensor=masked_position_weights, + axis=1, keepdims=True)) + # Weights list should come in with size [batch_size, 1], then will be + # expanded out to [batch_size, list_size] in the + # "_prepare_and_validate_params" step, so we need to reduce the Tensor back + # to size [batch_size, 1]. + per_list_weights = tf.reduce_mean( + input_tensor=weights, axis=1, keepdims=True) + return pwa, per_list_weights diff --git a/tensorflow_ranking/python/metrics_impl_test.py b/tensorflow_ranking/python/metrics_impl_test.py index d6d1fac..352ba92 100644 --- a/tensorflow_ranking/python/metrics_impl_test.py +++ b/tensorflow_ranking/python/metrics_impl_test.py @@ -1612,5 +1612,97 @@ def test_bpref_should_give_a_value_for_each_list_in_batch_inputs(self): self.assertAllClose([[0.], [1.]], output) +class PWAMetricTest(tf.test.TestCase): + + def test_pwa_should_be_single_value(self): + scores = [[1., 3., 2.], [4., 3., 2.], [8., 7., 6.]] + labels = [[0., 1., 2.], [4., 3., 2.], [8., 7., 6.]] + + metric = metrics_impl.PWAMetric(name=None, topn=5) + output, _ = metric.compute(labels, scores, None) + + self.assertAllClose(output, [[2. / (11. / 6.)], + [(4. / 1. + 3. / 2. + 2./ 3.) / + (1. / 1. + 1. / 2. + 1. / 3.)], + [(8. / 1. + 7. / 2. + 6./ 3.) / + (1. / 1. + 1. / 2. + 1. / 3.)]]) + + def test_pwa_should_be_0_when_no_rel_item(self): + scores = [[1., 3., 2.]] + labels = [[0., 0., 0.]] + + metric = metrics_impl.PWAMetric(name=None, topn=None) + output, _ = metric.compute(labels, scores, None) + + self.assertAllClose(output, [[0.]]) + + def test_pwa_should_be_0_when_no_rel_item_in_topn(self): + scores = [[1., 3., 2.]] + labels = [[0., 0., 1.]] + + metric = metrics_impl.PWAMetric(name=None, topn=1) + output, _ = metric.compute(labels, scores, None) + + self.assertAllClose(output, [[0.]]) + + def test_pwa_should_handle_topn(self): + scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]] + labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]] + + metric_top1 = metrics_impl.PWAMetric(name=None, topn=1) + metric_top2 = metrics_impl.PWAMetric(name=None, topn=2) + metric_top6 = metrics_impl.PWAMetric(name=None, topn=6) + output_top1, _ = metric_top1.compute(labels, scores, None) + output_top2, _ = metric_top2.compute(labels, scores, None) + output_top6, _ = metric_top6.compute(labels, scores, None) + + self.assertAllClose(output_top1, [[1.], [2.], [0.]]) + self.assertAllClose(output_top2, [[2. / (3. / 2.)], + [2. / (3. / 2.)], [1. / (3. / 2.)]]) + self.assertAllClose(output_top6, [[2. / (11. / 6.)], [7. / 3. / (11. / 6.)], + [4. / 3. / (11. / 6.)]]) + + def test_pwa_should_ignore_masked_items(self): + scores = [[1., 2., 3.]] + labels = [[0., 1., 3.]] + mask = [[True, False, True]] + + metric = metrics_impl.PWAMetric(name=None, topn=None) + output, _ = metric.compute(labels, scores, None, mask=mask) + + self.assertAllClose(output, [[3. / 1. / (3. / 2.)]]) + + def test_pwa_weights_no_input_weights(self): + scores = [[1., 3., 2.], [1., 2., 3.]] + labels = [[1., 4., 2.], [0., 3., 1.]] + + metric = metrics_impl.PWAMetric(name=None, topn=None) + _, output_weights = metric.compute(labels, scores) + + self.assertAllClose(output_weights, [[1.], [1.]]) + + def test_pwa_weights_should_be_average_weight_of_rel_items(self): + scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]] + labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]] + weights = [[2.], [3.], [4.]] + + metric = metrics_impl.PWAMetric(name=None, topn=None) + output_pwa, output_weights = metric.compute(labels, scores, weights) + + self.assertAllClose(output_weights, [[2.], [3.], [4.]]) + self.assertAllClose(output_pwa, [[2. / (11. / 6.)], + [7. / 3. / (11. / 6.)], + [4. / 3. / (11. / 6.)]]) + + def test_pwa_weights_should_raise_error_if_per_result(self): + scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]] + labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]] + weights = [[2., 3., 2.], [3., 7., 3.], [8., 1., 4.]] + + metric = metrics_impl.PWAMetric(name=None, topn=None) + with self.assertRaises(ValueError): + metric.compute(labels, scores, weights) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_ranking/python/metrics_test.py b/tensorflow_ranking/python/metrics_test.py index 6e95b30..8874b58 100644 --- a/tensorflow_ranking/python/metrics_test.py +++ b/tensorflow_ranking/python/metrics_test.py @@ -200,6 +200,30 @@ def test_make_mean_reciprocal_rank_fn(self): features), (sum([0., 1. / rel_rank[1], 0.]) / num_queries)), ]) + def test_make_pwa_fn(self): + with tf.Graph().as_default(): + scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]] + labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]] + weights = [[2.], [3.], [4.]] + num_queries = len(scores) + weights_feature_name = 'weights' + features = {weights_feature_name: weights} + m = metrics_lib.make_ranking_metric_fn(metrics_lib.RankingMetricKey.PWA) + m_2 = metrics_lib.make_ranking_metric_fn( + metrics_lib.RankingMetricKey.PWA, topn=2) + m_w = metrics_lib.make_ranking_metric_fn( + metrics_lib.RankingMetricKey.PWA, + weights_feature_name=weights_feature_name) + self._check_metrics([ + (m(labels, scores, features), (((2. + 7. / 3. + 4. / 3.) / (11. / 6.)) + / num_queries)), + (m_2(labels, scores, features), (((2. + 2. + 1.) / (3. / 2.)) + / num_queries)), + (m_w(labels, scores, features), + ((2. * 2. + 3. * 7. / 3. + 4 * 4. / 3.) / (11. / 6.)) + / (2. + 3. + 4.)), + ]) + def test_make_hits_fn(self): with tf.Graph().as_default(): scores = [[1., 3., 2.], [1., 2., 3.]]