diff --git a/tests/unit/torch/features/test_sequential.py b/tests/unit/torch/features/test_sequential.py
index 78c6dc5953..f58eb7cdb5 100644
--- a/tests/unit/torch/features/test_sequential.py
+++ b/tests/unit/torch/features/test_sequential.py
@@ -136,7 +136,6 @@ def test_sequential_tabular_features_ignore_masking(schema, torch_yoochoose_like
         input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
     )
 
-    assert np.allclose(output_wo_masking, output_inference_masking, rtol=1e-04, atol=1e-08)
     assert not np.allclose(output_wo_masking, output_clm_masking, rtol=1e-04, atol=1e-08)
 
     input_module._masking = MaskedLanguageModeling(hidden_size=100)
diff --git a/tests/unit/torch/test_masking.py b/tests/unit/torch/test_masking.py
index cd3542beb1..5a56718509 100644
--- a/tests/unit/torch/test_masking.py
+++ b/tests/unit/torch/test_masking.py
@@ -43,7 +43,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
     lm = tr.masking.masking_registry[task](
         hidden_dim, padding_idx=torch_masking_inputs["padding_idx"]
     )
-    lm.compute_masked_targets(torch_masking_inputs["labels"], training=False)
+    lm.compute_masked_targets(torch_masking_inputs["labels"], training=False, testing=True)
     # get non padded last items
     non_padded_mask = torch_masking_inputs["labels"] != torch_masking_inputs["padding_idx"]
     rows_ids = torch.arange(
@@ -57,7 +57,7 @@ def test_mask_only_last_item_for_eval(torch_masking_inputs, task):
     trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
     out_last = lm.masked_targets[trgt_pad].flatten().numpy()
     # check that only one item is masked for each session
-    assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+    assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
     # check only the last non-paded item is masked
     assert all(last_labels == out_last)
 
@@ -109,7 +109,7 @@ def test_clm_training_on_last_item(torch_masking_inputs):
     # last labels from output
     trgt_pad = lm.masked_targets != torch_masking_inputs["padding_idx"]
    out_last = lm.masked_targets[trgt_pad].flatten().numpy()
-    assert lm.mask_schema.sum() == torch_masking_inputs["input_tensor"].size(0)
+    assert trgt_pad.sum() == torch_masking_inputs["input_tensor"].size(0)
     assert all(last_labels == out_last)
 
 
diff --git a/transformers4rec/torch/masking.py b/transformers4rec/torch/masking.py
index e699d52f47..5a059e27de 100644
--- a/transformers4rec/torch/masking.py
+++ b/transformers4rec/torch/masking.py
@@ -274,6 +274,10 @@ def __init__(
     def _compute_masked_targets(
         self, item_ids: torch.Tensor, training: bool = False, testing: bool = False
     ) -> MaskingInfo:
+        if not training and not testing:
+            mask_labels = item_ids != self.padding_idx
+            return MaskingInfo(mask_labels, item_ids)
+
         masking_info = self.predict_all(item_ids)
         mask_labels, labels = masking_info.schema, masking_info.targets
 
@@ -290,7 +294,8 @@ def _compute_masked_targets(
             label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions]
             # Updating labels and mask
             labels = label_seq_trg_eval
-            mask_labels = label_seq_trg_eval != self.padding_idx
+            # We only mask padded positions
+            mask_labels = item_ids != self.padding_idx
 
         return MaskingInfo(mask_labels, labels)
 
@@ -302,6 +307,13 @@ def apply_mask_to_inputs(
         testing: bool = False,
     ) -> torch.Tensor:
         if not training and not testing:
+            # Replacing the inputs corresponding to padded items with a trainable embedding
+            # To mimic training and evaluation masking strategy
+            inputs = torch.where(
+                mask_schema.unsqueeze(-1).bool(),
+                inputs,
+                self.masked_item_embedding.to(inputs.dtype),
+            )
             return inputs
         # shift sequence of interaction embeddings
         pos_emb_inp = inputs[:, :-1]
@@ -316,7 +328,7 @@ def apply_mask_to_inputs(
             ],
             axis=1,
         )
-        # Replacing the inputs corresponding to masked label with a trainable embedding
+        # Replacing the inputs corresponding to padded items with a trainable embedding
         pos_emb_inp = torch.where(
             mask_schema.unsqueeze(-1).bool(),
             pos_emb_inp,
@@ -601,14 +613,16 @@ def _compute_masked_targets_extended(
                 # from the interval `[cur_len, cur_len + context_length - span_length]`
                 start_index = (
                     cur_len
-                    + torch.randint(  # type: ignore
-                        context_length - span_length + 1, (1,)
+                    + torch.randint(
+                        context_length - span_length + 1, (1,)  # type: ignore
                     ).item()
                 )
                 if start_index < max_len:
                     # Mask the span of non-padded items
                     # `start_index:start_index + span_length`
-                    mask_labels[i, start_index : start_index + span_length] = 1
+                    mask_labels[
+                        i, start_index : start_index + span_length  # type: ignore
+                    ] = 1
                 # Set `cur_len = cur_len + context_length`
                 cur_len += context_length
             # if no item was masked:
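For reference, a minimal sketch (not part of the patch) of what the new inference path of `CausalLanguageModeling` does when both `training=False` and `testing=False`: the mask schema simply flags non-padded positions and the targets are the unshifted item ids. The constructor values (`hidden_size=64`, `padding_idx=0`) and the toy tensor below are illustrative assumptions; the attributes `mask_schema` and `masked_targets` are the same ones the tests above inspect.

```python
import torch

from transformers4rec.torch.masking import CausalLanguageModeling

# Toy setup: hidden_size and padding_idx values are arbitrary for illustration.
clm = CausalLanguageModeling(hidden_size=64, padding_idx=0)

# Two padded sessions of item ids (0 = padding).
item_ids = torch.tensor([[5, 3, 7, 0, 0], [2, 9, 0, 0, 0]])

# Inference path added by this patch: neither training nor testing.
clm.compute_masked_targets(item_ids, training=False, testing=False)

# The schema only excludes padded positions ...
assert torch.equal(clm.mask_schema, item_ids != 0)
# ... and the targets are the item ids themselves (no shifting, no last-item masking).
assert torch.equal(clm.masked_targets, item_ids)
```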