
Add PWA@k metric.
PiperOrigin-RevId: 559189393
TensorFlow Ranking authored and lyyanlely committed Sep 26, 2023
1 parent 3e35c5c commit 61bcbf5
Showing 4 changed files with 226 additions and 1 deletion.
42 changes: 42 additions & 0 deletions tensorflow_ranking/python/metrics.py
@@ -72,6 +72,9 @@ class RankingMetricKey(object):
# Hits. For binary relevance.
HITS = 'hits'

# Position Weighted Average.
PWA = 'pwa'


def compute_mean(metric_key,
labels,
@@ -107,6 +110,7 @@ def compute_mean(metric_key,
RankingMetricKey.ORDERED_PAIR_ACCURACY: metrics_impl.OPAMetric(name),
RankingMetricKey.BPREF: metrics_impl.BPrefMetric(name, topn),
RankingMetricKey.HITS: metrics_impl.HitsMetric(metric_key, topn),
RankingMetricKey.PWA: metrics_impl.PWAMetric(metric_key, topn),
}
assert metric_key in metric_dict, ('metric_key %s not supported.' %
metric_key)
@@ -268,6 +272,15 @@ def _hits_fn(labels, predictions, features):
topn=topn,
name=name)

def _pwa_fn(labels, predictions, features):
"""Returns pwa as the metric."""
return pwa(
labels,
predictions,
weights=_get_weights(features),
topn=topn,
name=name)

metric_fn_dict = {
RankingMetricKey.ARP: _average_relevance_position_fn,
RankingMetricKey.MRR: _mean_reciprocal_rank_fn,
@@ -281,6 +294,7 @@ def _hits_fn(labels, predictions, features):
RankingMetricKey.ALPHA_DCG: _alpha_discounted_cumulative_gain_fn,
RankingMetricKey.BPREF: _binary_preference_fn,
RankingMetricKey.HITS: _hits_fn,
RankingMetricKey.PWA: _pwa_fn,
}
assert metric_key in metric_fn_dict, ('metric_key %s not supported.' %
metric_key)
@@ -716,3 +730,31 @@ def hits(labels,
# TODO: Add mask argument for metric.compute() call
hits_value, per_list_weights = metric.compute(labels, predictions, weights)
return tf.compat.v1.metrics.mean(hits_value, per_list_weights)


def pwa(labels,
predictions,
weights=None,
topn=None,
name=None):
"""Computes PWA.
Args:
labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
relevant example.
predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
the ranking score of the corresponding example.
weights: A `Tensor` of the shape [batch_size, 1], providing one weight per
list. Per-example weights of the shape [batch_size, list_size] are rejected
by this metric.
topn: An integer cutoff specifying how many examples to consider for this
metric. If None, the whole list is considered.
name: A string used as the name for this metric.
Returns:
A metric for the Position Weighted Average of the batch.
"""
metric = metrics_impl.PWAMetric(name, topn)
with tf.compat.v1.name_scope(metric.name, 'pwa',
(labels, predictions, weights)):
pwa_value, per_list_weights = metric.compute(labels, predictions, weights)
return tf.compat.v1.metrics.mean(pwa_value, per_list_weights)
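
A minimal usage sketch of the new metric through make_ranking_metric_fn,
mirroring the graph-mode setup in metrics_test.py below (the session and
initializer boilerplate here is an assumption about the caller's environment,
not part of this commit):

import tensorflow as tf
from tensorflow_ranking.python import metrics as metrics_lib

with tf.Graph().as_default():
  scores = [[3., 2., 1.]]
  labels = [[1., 2., 0.]]
  pwa_fn = metrics_lib.make_ranking_metric_fn(
      metrics_lib.RankingMetricKey.PWA, topn=2)
  # Like other tf.compat.v1 metrics, this yields a (value_op, update_op) pair.
  value_op, update_op = pwa_fn(labels, scores, features={})
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(value_op))  # (1*1 + 2*(1/2)) / (1 + 1/2) = 1.3333...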
69 changes: 68 additions & 1 deletion tensorflow_ranking/python/metrics_impl.py
@@ -163,7 +163,7 @@ def _per_list_recall(labels, predictions, topn, mask):
mask: A mask indicating which entries are valid for computing the metric.
Returns:
-     A `Tensor` of size [batch_size, 1] containing the precision of each query
+     A `Tensor` of size [batch_size, 1] containing the recall of each query
respectively.
"""
sorted_labels = utils.sort_by_scores(predictions, [labels], topn=topn,
@@ -896,3 +896,70 @@ def _compute_impl(self, labels, predictions, weights, mask):
relevance=tf.cast(tf.greater_equal(relevance, 1.0), dtype=tf.float32))

return bpref, per_list_weights


class PWAMetric(_RankingMetric):
"""Construct a custom Position-Weighted Average Metric.
For each query we order the results by scores and compute:
pwa = (ratings[0] * position_weights[0] + ...
+ ratings[topn - 1] * position_weights[topn - 1]) /
(position_weights[0] + ... + position_weights[topn - 1])
where position_weights = (1. / 1, 1. / 2, ..., 1. / topn)
The metric value for the whole dataset is a weighted sum over the pwa values
of individual queries:
result = pwa(query_0) * weights[0] + pwa(query_1) * weights[1] + ...
For this metric, weights should be a `Tensor` of the shape [batch_size, 1].
"""

def __init__(self, name, topn=5, ragged=False):
"""Constructor."""
super().__init__(ragged=ragged)
self._name = name
self._topn = topn

@property
def name(self):
"""The metric name."""
return self._name

def compute(self, labels, predictions, weights=None, mask=None):
"""See `_RankingMetric`."""
if weights is not None:
weights_tensor = tf.convert_to_tensor(value=weights)
predictions_tensor = tf.convert_to_tensor(value=predictions)
expected_shape = tf.zeros([tf.shape(predictions_tensor)[0], 1])
if not weights_tensor.shape.is_compatible_with(expected_shape.shape):
raise ValueError('Weights should be a `Tensor` of the shape '
                 '[batch_size, 1].')
return super().compute(labels, predictions, weights, mask)

def _compute_impl(self, labels, predictions, weights, mask):
"""See `_RankingMetric`."""
topn = tf.shape(predictions)[1] if self._topn is None else self._topn
sorted_labels, sorted_mask = utils.sort_by_scores(
predictions, [labels, mask], topn=topn, mask=mask)

sorted_list_size = tf.shape(input=sorted_labels)[1]
position_weights = 1.0 / tf.cast(
tf.range(1, sorted_list_size + 1), dtype=tf.float32)
masked_position_weights = (tf.cast(sorted_mask, dtype=tf.float32)
* position_weights)
pwa = tf.compat.v1.math.divide_no_nan(
tf.reduce_sum(input_tensor=tf.multiply(sorted_labels,
masked_position_weights),
axis=1, keepdims=True),
tf.reduce_sum(input_tensor=masked_position_weights,
axis=1, keepdims=True))
# Weights come in with shape [batch_size, 1] but are expanded to
# [batch_size, list_size] in the `_prepare_and_validate_params` step, so we
# reduce the tensor back to shape [batch_size, 1] here.
per_list_weights = tf.reduce_mean(
input_tensor=weights, axis=1, keepdims=True)
return pwa, per_list_weights
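
To make the arithmetic concrete, a plain-NumPy sketch of the per-query
computation above (illustration only: it ignores masking and the
divide_no_nan guard used by the real implementation):

import numpy as np

def pwa_single_query(labels, scores, topn=5):
  # Rank items by descending score and keep the top `topn` labels.
  order = np.argsort(-np.asarray(scores))[:topn]
  ratings = np.asarray(labels)[order]
  position_weights = 1.0 / np.arange(1, len(ratings) + 1)
  return (ratings * position_weights).sum() / position_weights.sum()

print(pwa_single_query([0., 1., 2.], [1., 3., 2.]))  # 2 / (11/6) ~= 1.0909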
92 changes: 92 additions & 0 deletions tensorflow_ranking/python/metrics_impl_test.py
@@ -1612,5 +1612,97 @@ def test_bpref_should_give_a_value_for_each_list_in_batch_inputs(self):
self.assertAllClose([[0.], [1.]], output)


class PWAMetricTest(tf.test.TestCase):

def test_pwa_should_be_single_value(self):
scores = [[1., 3., 2.], [4., 3., 2.], [8., 7., 6.]]
labels = [[0., 1., 2.], [4., 3., 2.], [8., 7., 6.]]

metric = metrics_impl.PWAMetric(name=None, topn=5)
output, _ = metric.compute(labels, scores, None)

self.assertAllClose(output, [[2. / (11. / 6.)],
                             [(4. / 1. + 3. / 2. + 2. / 3.) /
                              (1. / 1. + 1. / 2. + 1. / 3.)],
                             [(8. / 1. + 7. / 2. + 6. / 3.) /
                              (1. / 1. + 1. / 2. + 1. / 3.)]])

def test_pwa_should_be_0_when_no_rel_item(self):
scores = [[1., 3., 2.]]
labels = [[0., 0., 0.]]

metric = metrics_impl.PWAMetric(name=None, topn=None)
output, _ = metric.compute(labels, scores, None)

self.assertAllClose(output, [[0.]])

def test_pwa_should_be_0_when_no_rel_item_in_topn(self):
scores = [[1., 3., 2.]]
labels = [[0., 0., 1.]]

metric = metrics_impl.PWAMetric(name=None, topn=1)
output, _ = metric.compute(labels, scores, None)

self.assertAllClose(output, [[0.]])

def test_pwa_should_handle_topn(self):
scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]
labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]]

metric_top1 = metrics_impl.PWAMetric(name=None, topn=1)
metric_top2 = metrics_impl.PWAMetric(name=None, topn=2)
metric_top6 = metrics_impl.PWAMetric(name=None, topn=6)
output_top1, _ = metric_top1.compute(labels, scores, None)
output_top2, _ = metric_top2.compute(labels, scores, None)
output_top6, _ = metric_top6.compute(labels, scores, None)

self.assertAllClose(output_top1, [[1.], [2.], [0.]])
self.assertAllClose(output_top2, [[2. / (3. / 2.)],
[2. / (3. / 2.)], [1. / (3. / 2.)]])
self.assertAllClose(output_top6, [[2. / (11. / 6.)], [7. / 3. / (11. / 6.)],
[4. / 3. / (11. / 6.)]])

def test_pwa_should_ignore_masked_items(self):
scores = [[1., 2., 3.]]
labels = [[0., 1., 3.]]
mask = [[True, False, True]]

metric = metrics_impl.PWAMetric(name=None, topn=None)
output, _ = metric.compute(labels, scores, None, mask=mask)

self.assertAllClose(output, [[3. / 1. / (3. / 2.)]])
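
The expected value in this test relies on utils.sort_by_scores pushing masked
entries past the valid ones, so the surviving items are re-ranked at
positions 1, 2, ... before the position weights apply. The arithmetic,
spelled out in plain Python:

sorted_labels = [3., 0.]          # valid items only, ranked by score [3., 1.]
position_weights = [1., 1. / 2.]  # positions renumber over the valid items
pwa = (sum(l * w for l, w in zip(sorted_labels, position_weights))
       / sum(position_weights))
print(pwa)  # 3.0 / 1.5 == 2.0, i.e. [[3. / 1. / (3. / 2.)]]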

def test_pwa_weights_no_input_weights(self):
scores = [[1., 3., 2.], [1., 2., 3.]]
labels = [[1., 4., 2.], [0., 3., 1.]]

metric = metrics_impl.PWAMetric(name=None, topn=None)
_, output_weights = metric.compute(labels, scores)

self.assertAllClose(output_weights, [[1.], [1.]])

def test_pwa_weights_should_be_average_weight_of_rel_items(self):
scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]
labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]]
weights = [[2.], [3.], [4.]]

metric = metrics_impl.PWAMetric(name=None, topn=None)
output_pwa, output_weights = metric.compute(labels, scores, weights)

self.assertAllClose(output_weights, [[2.], [3.], [4.]])
self.assertAllClose(output_pwa, [[2. / (11. / 6.)],
[7. / 3. / (11. / 6.)],
[4. / 3. / (11. / 6.)]])

def test_pwa_weights_should_raise_error_if_per_result(self):
scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]
labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]]
weights = [[2., 3., 2.], [3., 7., 3.], [8., 1., 4.]]

metric = metrics_impl.PWAMetric(name=None, topn=None)
with self.assertRaises(ValueError):
metric.compute(labels, scores, weights)


if __name__ == '__main__':
tf.test.main()
24 changes: 24 additions & 0 deletions tensorflow_ranking/python/metrics_test.py
@@ -200,6 +200,30 @@ def test_make_mean_reciprocal_rank_fn(self):
features), (sum([0., 1. / rel_rank[1], 0.]) / num_queries)),
])

def test_make_pwa_fn(self):
with tf.Graph().as_default():
scores = [[3., 2., 1.], [3., 2., 1.], [3., 2., 1.]]
labels = [[1., 2., 0.], [2., 0., 1.], [0., 2., 1.]]
weights = [[2.], [3.], [4.]]
num_queries = len(scores)
weights_feature_name = 'weights'
features = {weights_feature_name: weights}
m = metrics_lib.make_ranking_metric_fn(metrics_lib.RankingMetricKey.PWA)
m_2 = metrics_lib.make_ranking_metric_fn(
metrics_lib.RankingMetricKey.PWA, topn=2)
m_w = metrics_lib.make_ranking_metric_fn(
metrics_lib.RankingMetricKey.PWA,
weights_feature_name=weights_feature_name)
self._check_metrics([
(m(labels, scores, features), (((2. + 7. / 3. + 4. / 3.) / (11. / 6.))
/ num_queries)),
(m_2(labels, scores, features), (((2. + 2. + 1.) / (3. / 2.))
/ num_queries)),
(m_w(labels, scores, features),
((2. * 2. + 3. * 7. / 3. + 4 * 4. / 3.) / (11. / 6.))
/ (2. + 3. + 4.)),
])
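
The m_w expectation above packs the whole derivation into one expression;
the same number, derived step by step in plain Python:

per_query_pwa = [2. / (11. / 6.),         # labels ranked by score: [1, 2, 0]
                 (7. / 3.) / (11. / 6.),  # [2, 0, 1]
                 (4. / 3.) / (11. / 6.)]  # [0, 2, 1]
weights = [2., 3., 4.]
weighted_mean = (sum(w * p for w, p in zip(weights, per_query_pwa))
                 / sum(weights))
print(weighted_mean)  # 98/99 ~= 0.9899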

def test_make_hits_fn(self):
with tf.Graph().as_default():
scores = [[1., 3., 2.], [1., 2., 3.]]
