Fix the CLM performance mismatch between evaluation and manual inference #723

Merged
merged 1 commit · Jun 21, 2023
1 change: 0 additions & 1 deletion tests/unit/torch/features/test_sequential.py
@@ -136,7 +136,6 @@ def test_sequential_tabular_features_ignore_masking(schema, torch_yoochoose_like
input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
)

- assert np.allclose(output_wo_masking, output_inference_masking, rtol=1e-04, atol=1e-08)
assert not np.allclose(output_wo_masking, output_clm_masking, rtol=1e-04, atol=1e-08)

input_module._masking = MaskedLanguageModeling(hidden_size=100)
6 changes: 3 additions & 3 deletions tests/unit/torch/test_masking.py
@@ -43,7 +43,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
lm = tr.masking.masking_registry[task](
hidden_dim, padding_idx=torch_masking_inputs["padding_idx"]
)
- lm.compute_masked_targets(torch_masking_inputs["labels"], training=False)
+ lm.compute_masked_targets(torch_masking_inputs["labels"], training=False, testing=True)
# get non padded last items
non_padded_mask = torch_masking_inputs["labels"] != torch_masking_inputs["padding_idx"]
rows_ids = torch.arange(
@@ -57,7 +57,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
out_last = lm.masked_targets[trgt_pad].flatten().numpy()
# check that only one item is masked for each session
- assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+ assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
# check only the last non-padded item is masked
assert all(last_labels == out_last)

@@ -109,7 +109,7 @@ def test_clm_training_on_last_item(torch_masking_inputs):
# last labels from output
trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
out_last = lm.masked_targets[trgt_pad].flatten().numpy()
- assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+ assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
assert all(last_labels == out_last)


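With the last-item evaluation strategy, the masked-target tensor keeps only the final non-padded item of each session, while `mask_schema` now flags every non-padded position; counting the non-padded targets (`trgt_pad.sum()`) is therefore what should equal the number of sessions. A minimal sketch of that invariant, using hand-built toy tensors (hypothetical values, not the `torch_masking_inputs` fixture):

```python
import torch

# Toy tensors standing in for the test fixture (hypothetical values, not from the PR).
padding_idx = 0
labels = torch.tensor([[1, 2, 3, 0],    # session 1: last real item is 3
                       [4, 5, 0, 0]])   # session 2: last real item is 5

# What the last-item evaluation strategy produces: only the final non-padded
# item of each session is kept as a target; every other position is padding.
masked_targets = torch.tensor([[0, 0, 3, 0],
                               [0, 5, 0, 0]])

trgt_pad = masked_targets != padding_idx
assert trgt_pad.sum() == labels.size(0)                           # one target per session
assert (masked_targets[trgt_pad] == torch.tensor([3, 5])).all()   # and it is the last item
```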
24 changes: 19 additions & 5 deletions transformers4rec/torch/masking.py
@@ -274,6 +274,10 @@ def __init__(
def _compute_masked_targets(
self, item_ids: torch.Tensor, training: bool = False, testing: bool = False
) -> MaskingInfo:
if not training and not testing:
mask_labels = item_ids != self.padding_idx
return MaskingInfo(mask_labels, item_ids)

masking_info = self.predict_all(item_ids)
mask_labels, labels = masking_info.schema, masking_info.targets

@@ -290,7 +294,8 @@ def _compute_masked_targets(
label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]
# Updating labels and mask
labels = label_seq_trg_eval
- mask_labels = label_seq_trg_eval != self.padding_idx
+ # We only mask padded positions
+ mask_labels = item_ids != self.padding_idx
(Inline review comment from a Member: "Good catch!")
return MaskingInfo(mask_labels, labels)

@@ -302,6 +307,13 @@ def apply_mask_to_inputs(
testing: bool = False,
) -> torch.Tensor:
if not training and not testing:
# Replacing the inputs corresponding to padded items with a trainable embedding
# To mimic training and evaluation masking strategy
inputs = torch.where(
mask_schema.unsqueeze(-1).bool(),
inputs,
self.masked_item_embedding.to(inputs.dtype),
)
return inputs
# shift sequence of interaction embeddings
pos_emb_inp = inputs[:, :-1]
Expand All @@ -316,7 +328,7 @@ def apply_mask_to_inputs(
],
axis=1,
)
- # Replacing the inputs corresponding to masked label with a trainable embedding
+ # Replacing the inputs corresponding to padded items with a trainable embedding
pos_emb_inp = torch.where(
mask_schema.unsqueeze(-1).bool(),
pos_emb_inp,
@@ -601,14 +613,16 @@ def _compute_masked_targets_extended(
# from the interval `[cur_len, cur_len + context_length - span_length]`
start_index = (
cur_len
- + torch.randint(  # type: ignore
- context_length - span_length + 1, (1,)
+ + torch.randint(
+ context_length - span_length + 1, (1,)  # type: ignore
).item()
)
if start_index < max_len:
# Mask the span of non-padded items
# `start_index:start_index + span_length`
- mask_labels[i, start_index : start_index + span_length] = 1
+ mask_labels[
+ i, start_index : start_index + span_length  # type: ignore
+ ] = 1
# Set `cur_len = cur_len + context_length`
cur_len += context_length
# if no item was masked:
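Taken together, the changes to `masking.py` make manual inference (neither `training` nor `testing`) consistent with evaluation for causal language modeling: `_compute_masked_targets` now returns a mask covering every non-padded position, and `apply_mask_to_inputs` replaces the embeddings at padded positions with the trainable masked-item embedding instead of returning the inputs untouched. A rough sketch of that inference-time path, with hand-built tensors and a zero vector standing in for the learned `masked_item_embedding`:

```python
import torch

# Hypothetical shapes and values, for illustration only.
padding_idx = 0
hidden_size = 4
item_ids = torch.tensor([[1, 2, 3, 0]])             # one session, last position padded

# Inference branch of _compute_masked_targets: mask every non-padded position.
mask_schema = item_ids != padding_idx                # [[True, True, True, False]]

# Inference branch of apply_mask_to_inputs: keep the embeddings of real items and
# replace padded positions with the (trainable) masked-item embedding.
inputs = torch.randn(1, item_ids.size(1), hidden_size)
masked_item_embedding = torch.zeros(hidden_size)     # stand-in for the learned parameter
masked_inputs = torch.where(
    mask_schema.unsqueeze(-1),
    inputs,
    masked_item_embedding.to(inputs.dtype),
)
assert torch.equal(masked_inputs[0, -1], masked_item_embedding)  # padded slot replaced
```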