Fixes on RecallAt metric for cut-off 1 and for scalar cut-off (#720)
* Fixed issues in RecallAt when cut-off 1 was used and when a scalar cut-off was provided

* Fixed exception when using a RankingMetric with a single cut-off (torchmetrics was converting the result to a scalar)
gabrielspmoreira authored and oliverholworthy committed Jun 14, 2023
1 parent f4946bf commit 5c943bb
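
As a quick illustration of the fixed behaviour, here is a minimal sketch (assuming transformers4rec with this commit applied; the score and label tensors mirror the new unit tests below):

import torch
from transformers4rec.torch.ranking_metric import RecallAt

# Three examples scored over 9 items; each label row marks a single relevant item.
preds = torch.tensor([[1, 2, 3, 4, 5, 4, 3, 2, 1]] * 3, dtype=torch.float32)
labels = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
    ]
)

# Cut-off 1 and a scalar cut-off are both accepted after this fix.
print(RecallAt([1], labels_onehot=False)(preds, labels))  # expected approx. tensor([0.3333])
print(RecallAt(4, labels_onehot=False)(preds, labels))    # expected approx. tensor([0.6667])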
Showing 4 changed files with 89 additions and 12 deletions.
54 changes: 53 additions & 1 deletion tests/unit/torch/test_ranking_metrics.py
@@ -18,7 +18,7 @@
import torch

import transformers4rec.torch as tr
- from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt
+ from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt, RecallAt

# fixed parameters for tests
list_metrics = list(tr.ranking_metric.ranking_metrics_registry.keys())
@@ -64,6 +64,58 @@ def test_mean_recipricol_rank():
)


def test_recall_at():
metric = RecallAt([1, 2, 3, 4], labels_onehot=False)
result = metric(
torch.tensor(
[[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]]
),
torch.tensor(
[[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0]]
),
)
assert torch.all(
torch.lt(
torch.abs(torch.add(result, -torch.tensor([0.3333, 0.3333, 0.6667, 0.6667]))), 1e-3
)
)


def test_recall_at_3d():
metric = RecallAt([1, 2, 3, 4], labels_onehot=False)
result = metric(
torch.tensor(
[
[[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]],
[[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]],
]
),
torch.tensor(
[
[[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0]],
[[0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0]],
]
),
)
assert torch.all(
torch.lt(torch.abs(torch.add(result, -torch.tensor([0.25, 0.25, 0.75, 0.75]))), 1e-3)
)


@pytest.mark.parametrize("cutoff", [4, [4]])
def test_recall_at_single_metric(cutoff):
metric = RecallAt(cutoff, labels_onehot=False)
result = metric(
torch.tensor(
[[1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1], [1, 2, 3, 4, 5, 4, 3, 2, 1]]
),
torch.tensor(
[[0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0]]
),
)
assert torch.all(torch.lt(torch.abs(torch.add(result, -torch.tensor([0.6667]))), 1e-3))


# TODO: Compare the metrics @K between pytorch and numpy
@pytest.mark.parametrize("metric", list_metrics)
def test_numpy_comparison(metric):
19 changes: 17 additions & 2 deletions tests/unit/torch/test_trainer.py
@@ -24,6 +24,7 @@
import transformers4rec.torch as tr
from transformers4rec.config import trainer
from transformers4rec.config import transformer as tconf
from transformers4rec.torch.ranking_metric import NDCGAt, RecallAt


@pytest.mark.parametrize("batch_size", [16, 32])
@@ -327,6 +328,20 @@ def test_evaluate_results(torch_yoochoose_next_item_prediction_model):
"eval_/next-item/recall_at_20",
],
),
(
tr.NextItemPredictionTask(
weight_tying=False,
metrics=(
NDCGAt(top_ks=[5, 10], labels_onehot=True),
RecallAt(top_ks=[10], labels_onehot=True),
),
),
[
"eval_/next-item/ndcg_at_5",
"eval_/next-item/ndcg_at_10",
"eval_/next-item/recall_at_10",
],
),
(
tr.BinaryClassificationTask("click", summary_type="mean"),
[
@@ -347,7 +362,7 @@ def test_trainer_music_streaming(task_and_metrics):
data = tr.data.music_streaming_testing_data
schema = data.schema
batch_size = 16
- task, default_metric = task_and_metrics
+ task, expected_metrics = task_and_metrics

inputs = tr.TabularSequenceFeatures.from_schema(
schema,
@@ -388,7 +403,7 @@
predictions = recsys_trainer.predict(data.path)

assert isinstance(eval_metrics, dict)
- assert set(default_metric).issubset(set(eval_metrics.keys()))
+ assert set(expected_metrics).issubset(set(eval_metrics.keys()))
assert eval_metrics["eval_/loss"] is not None

assert predictions is not None
7 changes: 6 additions & 1 deletion transformers4rec/torch/model/prediction_task.py
@@ -486,7 +486,8 @@ def calculate_metrics(self, predictions, targets) -> Dict[str, torch.Tensor]: #
predictions = self.forward_to_prediction_fn(predictions)

for metric in self.metrics:
- outputs[self.metric_name(metric)] = metric(predictions, targets)
+ result = metric(predictions, targets)
+ outputs[self.metric_name(metric)] = result

return outputs

@@ -502,6 +503,10 @@ def compute_metrics(self):
topks = {self.metric_name(metric): metric.top_ks for metric in self.metrics}
results = {}
for name, metric in metrics.items():
# Fix for when using a single cut-off, as torch metrics convert results to scalar
# when a single element vector is returned
if len(metric.size()) == 0:
metric = metric.unsqueeze(0)
for measure, k in zip(metric, topks[name]):
results[f"{name}_{k}"] = measure
return results
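
For context, the guard added above works around torchmetrics handing back a 0-dim tensor when a single cut-off is configured, which would otherwise break the zip over top_ks. A standalone sketch of the same normalization in plain PyTorch (the metric name and value are illustrative):

import torch

metric_value = torch.tensor(0.6667)  # 0-dim tensor, as produced for a single cut-off
top_ks = [4]
if len(metric_value.size()) == 0:
    metric_value = metric_value.unsqueeze(0)  # now shape (1,), so it can be iterated
results = {f"recall_at_{k}": v for v, k in zip(metric_value, top_ks)}
print(results)  # {'recall_at_4': tensor(0.6667)}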
21 changes: 13 additions & 8 deletions transformers4rec/torch/ranking_metric.py
@@ -42,6 +42,9 @@ class RankingMetric(tm.Metric):
def __init__(self, top_ks=None, labels_onehot=False):
super(RankingMetric, self).__init__()
self.top_ks = top_ks or [2, 5]
if not isinstance(self.top_ks, (list, tuple)):
self.top_ks = [self.top_ks]

self.labels_onehot = labels_onehot
# Store the mean of the batch metrics (for each cut-off at topk)
self.add_state("metric_mean", default=[], dist_reduce_fx="cat")
@@ -50,7 +53,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor, **kwargs): # type:
# Computing the metrics at different cut-offs
if self.labels_onehot:
target = torch_utils.tranform_label_to_onehot(target, preds.size(-1))
- metric = self._metric(self.top_ks, preds.view(-1, preds.size(-1)), target)
+ metric = self._metric(
+     self.top_ks, preds.view(-1, preds.size(-1)), target.view(-1, target.size(-1))
+ )
self.metric_mean.append(metric) # type: ignore

def compute(self):
@@ -126,17 +131,17 @@ def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor)

# Compute recalls at K
num_relevant = torch.sum(labels, dim=-1)
- rel_indices = (num_relevant != 0).nonzero()
- rel_count = num_relevant[rel_indices].squeeze()
+ rel_indices = (num_relevant != 0).nonzero().squeeze()
+ rel_count = num_relevant[rel_indices]

if rel_indices.shape[0] > 0:
for index, k in enumerate(ks):
- rel_labels = topk_labels[rel_indices, : int(k)].squeeze()
+ rel_labels = topk_labels[rel_indices, : int(k)]

- recalls[rel_indices, index] = (
-     torch.div(torch.sum(rel_labels, dim=-1), rel_count)
-     .reshape(len(rel_indices), 1)
-     .to(dtype=torch.float32)
+ recalls[rel_indices, index] = torch.div(
+     torch.sum(rel_labels, dim=-1), rel_count
+ ).to(
+     dtype=torch.float32
) # Ensuring type is double, because it can be float if --fp16

return recalls
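
To sanity-check the reworked recall computation, here is a hand-rolled recall@k over the same tensors used in test_recall_at (a rough sketch, not part of the library; ties in the scores are resolved by torch.topk, matching the values asserted in the tests):

import torch

def recall_at_k(scores: torch.Tensor, labels: torch.Tensor, k: int) -> torch.Tensor:
    # Fraction of each row's relevant items that appear among the top-k scored
    # positions, averaged over rows.
    _, topk_idx = torch.topk(scores, k, dim=-1)
    hits = torch.gather(labels, -1, topk_idx).sum(dim=-1).float()
    num_relevant = labels.sum(dim=-1).clamp(min=1).float()
    return (hits / num_relevant).mean()

scores = torch.tensor([[1, 2, 3, 4, 5, 4, 3, 2, 1]] * 3, dtype=torch.float32)
labels = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0],
    ]
)
print(recall_at_k(scores, labels, 1))  # approx. 0.3333
print(recall_at_k(scores, labels, 3))  # approx. 0.6667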
