diff --git a/data/datasets/s3rec_data_pipeline.py b/data/datasets/s3rec_data_pipeline.py index a2e881c..ecf5547 100644 --- a/data/datasets/s3rec_data_pipeline.py +++ b/data/datasets/s3rec_data_pipeline.py @@ -28,9 +28,9 @@ def split(self, df: pd.DataFrame): valid_df_X = self._adjust_seq_len(valid_df_X) test_df_X = self._adjust_seq_len(test_df_X) - return pd.concat([train_df_X, train_df_Y], axis=1),\ - pd.concat([valid_df_X, valid_df_Y], axis=1),\ - pd.concat([test_df_X, test_df_Y], axis=1) + return pd.concat([df, train_df_X, train_df_Y], axis=1),\ + pd.concat([df, valid_df_X, valid_df_Y], axis=1),\ + pd.concat([df, test_df_X, test_df_Y], axis=1) def _adjust_seq_len(self, df): diff --git a/data/datasets/s3rec_dataset.py b/data/datasets/s3rec_dataset.py index e7b657d..cbd575a 100644 --- a/data/datasets/s3rec_dataset.py +++ b/data/datasets/s3rec_dataset.py @@ -16,12 +16,12 @@ def __init__(self, data, num_items=None, train=True): def __len__(self): return self.data.shape[0] - def _negative_sampling(self, pos_item): + def _negative_sampling(self, behaviors): sample_size = 1 if self.train else 99 neg_items = [] for _ in range(sample_size): neg_item = np.random.randint(1, self.num_items+1) - while (neg_item == pos_item) or (neg_item in neg_items): + while (neg_item in behaviors) or (neg_item in neg_items): neg_item = np.random.randint(1, self.num_items+1) neg_items.append(neg_item) return neg_items @@ -34,12 +34,12 @@ def __getitem__(self, user_id): 'user_id': user_id, 'X': data['X'], 'pos_item': pos_item, - 'neg_item': self._negative_sampling(pos_item)[0] + 'neg_item': self._negative_sampling(data['behaviors'])[0] } else: return { 'user_id': user_id, 'X': data['X'], 'pos_item': pos_item, - 'neg_items': self._negative_sampling(pos_item) + 'neg_items': self._negative_sampling(data['behaviors']) }