Skip to content

Commit

Permalink
fix of clm performance (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
sararb authored and oliverholworthy committed Jun 21, 2023
1 parent e2c8eb2 commit edafe97
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
1 change: 0 additions & 1 deletion tests/unit/torch/features/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def test_sequential_tabular_features_ignore_masking(schema, torch_yoochoose_like
input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
)

assert np.allclose(output_wo_masking, output_inference_masking, rtol=1e-04, atol=1e-08)
assert not np.allclose(output_wo_masking, output_clm_masking, rtol=1e-04, atol=1e-08)

input_module._masking = MaskedLanguageModeling(hidden_size=100)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/torch/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
lm = tr.masking.masking_registry[task](
hidden_dim, padding_idx=torch_masking_inputs["padding_idx"]
)
lm.compute_masked_targets(torch_masking_inputs["labels"], training=False)
lm.compute_masked_targets(torch_masking_inputs["labels"], training=False, testing=True)
# get non padded last items
non_padded_mask = torch_masking_inputs["labels"] != torch_masking_inputs["padding_idx"]
rows_ids = torch.arange(
Expand All @@ -57,7 +57,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
out_last = lm.masked_targets[trgt_pad].flatten().numpy()
# check that only one item is masked for each session
assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
# check only the last non-paded item is masked
assert all(last_labels == out_last)

Expand Down Expand Up @@ -109,7 +109,7 @@ def test_clm_training_on_last_item(torch_masking_inputs):
# last labels from output
trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
out_last = lm.masked_targets[trgt_pad].flatten().numpy()
assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
assert all(last_labels == out_last)


Expand Down
24 changes: 19 additions & 5 deletions transformers4rec/torch/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def __init__(
def _compute_masked_targets(
self, item_ids: torch.Tensor, training: bool = False, testing: bool = False
) -> MaskingInfo:
if not training and not testing:
mask_labels = item_ids != self.padding_idx
return MaskingInfo(mask_labels, item_ids)

masking_info = self.predict_all(item_ids)
mask_labels, labels = masking_info.schema, masking_info.targets

Expand All @@ -290,7 +294,8 @@ def _compute_masked_targets(
label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]
# Updating labels and mask
labels = label_seq_trg_eval
mask_labels = label_seq_trg_eval != self.padding_idx
# We only mask padded positions
mask_labels = item_ids != self.padding_idx

return MaskingInfo(mask_labels, labels)

Expand All @@ -302,6 +307,13 @@ def apply_mask_to_inputs(
testing: bool = False,
) -> torch.Tensor:
if not training and not testing:
# Replacing the inputs corresponding to padded items with a trainable embedding
# To mimic training and evaluation masking strategy
inputs = torch.where(
mask_schema.unsqueeze(-1).bool(),
inputs,
self.masked_item_embedding.to(inputs.dtype),
)
return inputs
# shift sequence of interaction embeddings
pos_emb_inp = inputs[:, :-1]
Expand All @@ -316,7 +328,7 @@ def apply_mask_to_inputs(
],
axis=1,
)
# Replacing the inputs corresponding to masked label with a trainable embedding
# Replacing the inputs corresponding to padded items with a trainable embedding
pos_emb_inp = torch.where(
mask_schema.unsqueeze(-1).bool(),
pos_emb_inp,
Expand Down Expand Up @@ -601,14 +613,16 @@ def _compute_masked_targets_extended(
# from the interval `[cur_len, cur_len + context_length - span_length]`
start_index = (
cur_len
+ torch.randint( # type: ignore
context_length - span_length + 1, (1,)
+ torch.randint(
context_length - span_length + 1, (1,) # type: ignore
).item()
)
if start_index < max_len:
# Mask the span of non-padded items
# `start_index:start_index + span_length`
mask_labels[i, start_index : start_index + span_length] = 1
mask_labels[
i, start_index : start_index + span_length # type: ignore
] = 1
# Set `cur_len = cur_len + context_length`
cur_len += context_length
# if no item was masked:
Expand Down

0 comments on commit edafe97

Please sign in to comment.