Fixed OOM issues when evaluating/predicting (#721)
* Fixed OOM issues when evaluating/predicting

* Fixed test

* Improved error message

* Fixed test
gabrielspmoreira authored and oliverholworthy committed Jun 14, 2023
1 parent 5c943bb commit e2c8eb2
Showing 3 changed files with 149 additions and 45 deletions.
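
For context on how this change surfaces to callers, the sketch below (not part of the commit) mirrors the music-streaming setup from the tests in this PR; the imports and builder calls are assumed from tests/unit/torch/test_trainer.py, together with the documented new default predict_top_k of 100.

import transformers4rec.torch as tr
from transformers4rec.config import trainer
from transformers4rec.config import transformer as tconf

data = tr.data.music_streaming_testing_data
schema = data.schema

# A small next-item prediction model, built the same way as in the tests below.
inputs = tr.TabularSequenceFeatures.from_schema(
    schema, max_sequence_length=20, d_output=64, masking="clm"
)
task = tr.NextItemPredictionTask(weight_tying=True)
model = tconf.XLNetConfig.build(64, 4, 2, 20).to_torch_model(inputs, task)

# predict_top_k now defaults to 100, so evaluation/prediction keeps only the 100
# highest-scoring items per row instead of scores for the whole item catalog.
args = trainer.T4RecTrainingArguments(
    output_dir=".",
    per_device_eval_batch_size=8,
    max_sequence_length=20,
    data_loader_engine="merlin_dataloader",
    report_to=[],
)

recsys_trainer = tr.Trainer(
    model=model,
    args=args,
    schema=schema,
    train_dataset_or_path=data.path,
    eval_dataset_or_path=data.path,
    test_dataset_or_path=data.path,
    compute_metrics=True,
)

outputs = recsys_trainer.predict(data.path)
# With top-k truncation, predictions come back as an (item_ids, scores) tuple,
# each of shape (num_rows, predict_top_k).
top_item_ids, top_scores = outputs.predictions
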
99 changes: 92 additions & 7 deletions tests/unit/torch/test_trainer.py
@@ -407,9 +407,14 @@ def test_trainer_music_streaming(task_and_metrics):
assert eval_metrics["eval_/loss"] is not None

assert predictions is not None

DEFAULT_PREDICT_TOP_K = 100

# 1000 is the total samples in the testing data
if isinstance(task, tr.NextItemPredictionTask):
assert predictions.predictions.shape == (1000, task.target_dim)
top_predicted_item_ids, top_prediction_scores = predictions.predictions
assert top_predicted_item_ids.shape == (1000, DEFAULT_PREDICT_TOP_K)
assert top_prediction_scores.shape == (1000, DEFAULT_PREDICT_TOP_K)
else:
assert predictions.predictions.shape == (1000,)

@@ -573,17 +578,79 @@ def test_trainer_with_multiple_tasks(schema_type):
assert predictions.predictions["click/binary_classification_task"].shape == (1000,)


def test_trainer_trop_k_with_wrong_task():
@pytest.mark.parametrize("predict_top_k", [20, None, "default"])
def test_trainer_predict_topk(predict_top_k):
DEFAULT_PREDICT_TOP_K = 100

data = tr.data.music_streaming_testing_data
schema = data.schema
batch_size = 16
predict_top_k = 20

task = tr.BinaryClassificationTask("click", summary_type="mean")
task = tr.NextItemPredictionTask(weight_tying=True)
inputs = tr.TabularSequenceFeatures.from_schema(
schema,
max_sequence_length=20,
d_output=64,
masking="clm",
)
transformer_config = tconf.XLNetConfig.build(64, 4, 2, 20)
model = transformer_config.to_torch_model(inputs, task)

additional_args = {}
if not isinstance(predict_top_k, str):
additional_args["predict_top_k"] = predict_top_k

args = trainer.T4RecTrainingArguments(
output_dir=".",
num_train_epochs=1,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size // 2,
data_loader_engine="merlin_dataloader",
max_sequence_length=20,
report_to=[],
debug=["r"],
**additional_args,
)

recsys_trainer = tr.Trainer(
model=model,
args=args,
schema=schema,
train_dataset_or_path=data.path,
eval_dataset_or_path=data.path,
test_dataset_or_path=data.path,
compute_metrics=True,
)

outputs = recsys_trainer.predict(data.path)

if predict_top_k is None:
assert outputs.predictions.shape[1] == 10001
else:
if predict_top_k == "default":
predict_top_k = DEFAULT_PREDICT_TOP_K

pred_item_ids, pred_scores = outputs.predictions
assert len(pred_item_ids.shape) == 2
assert pred_item_ids.shape[1] == predict_top_k
assert len(pred_scores.shape) == 2
assert pred_scores.shape[1] == predict_top_k


@pytest.mark.parametrize("predict_top_k", [15, 20, 30, None])
@pytest.mark.parametrize("top_k", [20, None])
def test_trainer_predict_top_k_x_top_k(predict_top_k, top_k):
data = tr.data.music_streaming_testing_data
schema = data.schema
batch_size = 16

task = tr.NextItemPredictionTask(weight_tying=True)

inputs = tr.TabularSequenceFeatures.from_schema(
schema,
max_sequence_length=20,
d_output=64,
masking="clm",
)
transformer_config = tconf.XLNetConfig.build(64, 4, 2, 20)
model = transformer_config.to_torch_model(inputs, task)
@@ -609,10 +676,28 @@ def test_trainer_trop_k_with_wrong_task():
test_dataset_or_path=data.path,
compute_metrics=True,
)
with pytest.raises(AssertionError) as excinfo:
recsys_trainer.predict(data.path)

assert "Top-k prediction is specific to NextItemPredictionTask" in str(excinfo.value)
model.top_k = top_k

if predict_top_k and top_k and predict_top_k > top_k:
with pytest.raises(ValueError) as excinfo:
recsys_trainer.predict(data.path)
assert "The args.predict_top_k should not be larger than model.top_k" in str(excinfo.value)

else:
outputs = recsys_trainer.predict(data.path)

if predict_top_k or top_k:
expected_top_k = predict_top_k or top_k

pred_item_ids, pred_scores = outputs.predictions
assert len(pred_item_ids.shape) == 2
assert pred_item_ids.shape[1] == expected_top_k
assert len(pred_scores.shape) == 2
assert pred_scores.shape[1] == expected_top_k
else:
ITEM_CARDINALITY = 10001
assert outputs.predictions.shape[1] == ITEM_CARDINALITY


def test_trainer_with_pretrained_embeddings():
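Rough arithmetic behind the OOM fix, using the sizes that appear in these tests (1000 prediction rows, ITEM_CARDINALITY of 10001, and the new default predict_top_k of 100); treating every value as 4 bytes is an assumption for illustration.

num_rows = 1000            # "1000 is the total samples in the testing data"
item_cardinality = 10001   # ITEM_CARDINALITY in test_trainer_predict_top_k_x_top_k
predict_top_k = 100        # new default in T4RecTrainingArguments
bytes_per_value = 4        # assuming float32 scores (item ids may be wider)

full_scores = num_rows * item_cardinality * bytes_per_value   # ~40 MB in this toy setup
truncated = num_rows * predict_top_k * 2 * bytes_per_value    # ids + scores, ~0.8 MB
print(full_scores / truncated)                                # ~50x smaller

# With a production-sized item catalog and many more rows, the untruncated
# (num_rows x item_cardinality) matrix accumulated across evaluation batches is
# what runs out of memory; truncating per batch keeps it bounded by predict_top_k.
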
9 changes: 5 additions & 4 deletions transformers4rec/config/trainer.py
@@ -38,8 +38,9 @@ class T4RecTrainingArguments(TrainingArguments):
predict_top_k: Optional[int], int
Truncate recommendation list to the highest top-K predicted items,
(does not affect evaluation metrics computation),
this parameter is specific to NextItemPredictionTask.
by default 0
This parameter is specific to NextItemPredictionTask and only affects
trainer.predict() and trainer.evaluate(), which both call `Trainer.evaluation_loop`.
By default 100.
log_predictions : Optional[bool], bool
log predictions, labels and metadata features each --compute_metrics_each_n_steps
(for test set).
@@ -90,10 +91,10 @@ class T4RecTrainingArguments(TrainingArguments):
)

predict_top_k: int = field(
default=0,
default=100,
metadata={
"help": "Truncate recommendation list to the highest top-K predicted items (do not affect evaluation metrics computation), "
"this parameter is specific to NextItemPredictionTask."
"this parameter is specific to NextItemPredictionTask. Default is 100."
},
)

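The new default can presumably be overridden, or truncation switched off altogether, through this same field; a sketch relying on the falsy check in evaluation_loop shown in the next file (predict_top_k=0 only restores the full-score output when model.top_k is not set).

from transformers4rec.config import trainer

# Keep only the 20 best items per row, as the parametrized test above does.
args_small = trainer.T4RecTrainingArguments(output_dir=".", predict_top_k=20)

# A falsy value (0 or None) skips the torch.topk truncation and returns the full
# (num_rows x item_cardinality) score matrix, i.e. the previous, memory-hungry behavior.
args_full = trainer.T4RecTrainingArguments(output_dir=".", predict_top_k=0)
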
86 changes: 52 additions & 34 deletions transformers4rec/torch/trainer.py
@@ -528,50 +528,68 @@ def evaluation_loop(
if labels_host is None
else nested_concat(labels_host, labels, padding_index=0)
)
if preds is not None and self.args.predict_top_k > 0:
if self.model.top_k:
raise ValueError(
"you cannot set top_k argument in the model class and the, "
"predict_top_k in the trainer at the same time. Please ensure setting "
"only predict_top_k"
)

if (
preds is not None
and any(isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks)
and (self.args.predict_top_k or self.model.top_k)
):
# get outputs of next-item scores
if isinstance(preds, dict):
assert any(
isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks
), "Top-k prediction is specific to NextItemPredictionTask, "
"Please ensure `self.args.predict_top_k == 0` "
pred_next_item = preds["next-item"]
else:
assert isinstance(
model.prediction_tasks[0], NextItemPredictionTask
), "Top-k prediction is specific to NextItemPredictionTask, "
"Please ensure `self.args.predict_top_k == 0` "
pred_next_item = preds

preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
pred_next_item, k=self.args.predict_top_k, dim=-1
)
self._maybe_log_predictions(
labels,
preds_sorted_item_ids,
preds_sorted_item_scores,
# outputs["pred_metadata"],
metrics_results_detailed,
metric_key_prefix,
)
# The output predictions will be a tuple with the ranked top-n item ids,
# and item recommendation scores
if isinstance(preds, dict):
preds["next-item"] = (
preds_sorted_item_ids,
preds_sorted_item_scores,
preds_sorted_item_scores = None
preds_sorted_item_ids = None

if self.model.top_k is not None and isinstance(pred_next_item, (list, tuple)):
preds_sorted_item_scores, preds_sorted_item_ids = pred_next_item

if self.args.predict_top_k:
if self.args.predict_top_k > self.model.top_k:
raise ValueError(
"The args.predict_top_k should not be larger than model.top_k. "
"The model.top_k is available to support inference (e.g. when "
"serving with Triton Inference Server) to return only the top-k "
"predicted items ids and their scores."
"When doing offline predictions with `trainer.predict(), "
"if you set model.top_k, the model will also limit the number of "
"predictions output from trainer.predict(). "
"In that case, you want either to reduce args.predict_top_k or "
"increase model.top_k, so that args.predict_top_k is "
"not larger than model.top_k."
)
preds_sorted_item_scores = preds_sorted_item_scores[
:, : self.args.predict_top_k
]
preds_sorted_item_ids = preds_sorted_item_ids[:, : self.args.predict_top_k]
elif self.args.predict_top_k:
preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
pred_next_item, k=self.args.predict_top_k, dim=-1
)
else:
preds = (

if preds_sorted_item_scores is not None:
self._maybe_log_predictions(
labels,
preds_sorted_item_ids,
preds_sorted_item_scores,
# outputs["pred_metadata"],
metrics_results_detailed,
metric_key_prefix,
)
# The output predictions will be a tuple with the ranked top-n item ids,
# and item recommendation scores
if isinstance(preds, dict):
preds["next-item"] = (
preds_sorted_item_ids,
preds_sorted_item_scores,
)
else:
preds = (
preds_sorted_item_ids,
preds_sorted_item_scores,
)

preds_host = (
preds
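A usage-level summary of the precedence implemented above, as a sketch; it assumes model, recsys_trainer, and data were built as in the earlier example, and sets the attributes directly for brevity (the tests pass them via T4RecTrainingArguments and model construction).

# model.top_k truncates the model's own outputs (e.g. for Triton serving);
# args.predict_top_k further truncates what trainer.predict()/evaluate() return.
model.top_k = 30
recsys_trainer.args.predict_top_k = 20

item_ids, scores = recsys_trainer.predict(data.path).predictions
assert item_ids.shape[1] == 20  # predict_top_k applies when both are set and valid

# Requesting more items than the model emits cannot be satisfied and raises the
# ValueError introduced in this commit.
recsys_trainer.args.predict_top_k = 50  # larger than model.top_k
# recsys_trainer.predict(data.path)     # -> ValueError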
