Skip to content

Commit

Permalink
feat: implements aap actual #21
Browse files Browse the repository at this point in the history
  • Loading branch information
GangBean committed Aug 2, 2024
1 parent 18fca50 commit ced93d5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
8 changes: 6 additions & 2 deletions data/datasets/s3rec_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def preprocess(self) -> pd.DataFrame:

# load attributes
self.item2attributes = self._load_attributes()
logger.info(f"item2attributes : {len(self.item2attributes)}")

logger.info("done")
return df
Expand All @@ -75,8 +76,11 @@ def _load_attributes(self):
logger.info("load item2attributes...")
df = pd.read_json(os.path.join(self.cfg.data_dir, 'yelp_item2attributes.json')).transpose()
self.attributes_count = df.categories.explode().nunique()

return df.drop(columns=['statecity']).transpose().to_dict()

df = df.drop(columns=['statecity']).transpose().to_dict()
df = {key+1:value for key,value in df.items()}
df.update({0: {'categories': []}})
return df


def _set_num_items_and_num_users(self, df):
Expand Down
12 changes: 9 additions & 3 deletions data/datasets/s3rec_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

class S3RecDataset(Dataset):

def __init__(self, data, num_items=None, train=True):
def __init__(self, data, item2attribute, attributes_count, num_items=None, train=True):
super().__init__()
self.data = data
self.num_items = num_items
self.train = train
self.item2attribute = item2attribute
self.attributes_count = attributes_count

def __len__(self):
return self.data.shape[0]
Expand All @@ -29,17 +31,21 @@ def _negative_sampling(self, behaviors):
def __getitem__(self, user_id):
data = self.data.iloc[user_id,:]
pos_item = data['y'].astype('int64')
aap_actual = np.array([[1 if attriute in self.item2attribute[item]['categories'] else 0 \
for attriute in range(self.attributes_count)] for item in data['X']], dtype='float')
if self.train:
return {
'user_id': user_id,
'X': np.array(data['X'], dtype='int64'),
'pos_item': pos_item,
'neg_item': self._negative_sampling(data['behaviors'])[0]
'neg_item': self._negative_sampling(data['behaviors'])[0],
'aap_actual': aap_actual
}
else:
return {
'user_id': user_id,
'X': np.array(data['X'], dtype='int64'),
'pos_item': pos_item,
'neg_items': np.array(self._negative_sampling(data['behaviors']), dtype='int64')
'neg_items': np.array(self._negative_sampling(data['behaviors']), dtype='int64'),
'aap_actual': aap_actual
}
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ def main(cfg: OmegaConf):
model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
elif cfg.model_name == 'S3Rec':
train_data, valid_data, test_data = data_pipeline.split(df)
train_dataset = S3RecDataset(train_data, num_items=data_pipeline.num_items)
valid_dataset = S3RecDataset(valid_data, num_items=data_pipeline.num_items)
test_dataset = S3RecDataset(test_data, num_items=data_pipeline.num_items, train=False)
train_dataset = S3RecDataset(train_data, data_pipeline.item2attributes, data_pipeline.attributes_count, num_items=data_pipeline.num_items)
valid_dataset = S3RecDataset(valid_data, data_pipeline.item2attributes, data_pipeline.attributes_count, num_items=data_pipeline.num_items)
test_dataset = S3RecDataset(test_data, data_pipeline.item2attributes, data_pipeline.attributes_count, num_items=data_pipeline.num_items, train=False)
args.update({'test_dataset': test_dataset})
model_info['num_items'], model_info['num_users'] = data_pipeline.num_items, data_pipeline.num_users
else:
Expand Down
7 changes: 2 additions & 5 deletions trainers/s3rec_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def train(self, train_dataloader) -> float:

for data in tqdm(train_dataloader): # sequence
sequences = data['X'].to(self.device)
aap_actual = data['aap_actual'].to(self.device)
# item_masked_sequences
masks, item_masked_sequences = self.item_level_masking(sequences)
# segment_masked_sequences
Expand All @@ -136,11 +137,7 @@ def train(self, train_dataloader) -> float:
item_masked_sequences, segment_masked_sequences, pos_segments, neg_segments)

# AAP: item + attributes
aap_actual = torch.ones_like(aap_output).to(self.device)
# actual = torch.Tensor([
# [1 if attriute in self.item2attribute[item.item()] else 0 \
# for attriute in range(self.attributes_count)] for item in items]
# ).to(self.device) # (item_chunk_size, attributes_count)
aap_actual = aap_actual * masks.unsqueeze(-1)
## compute unmasked area only
aap_loss = nn.functional.binary_cross_entropy_with_logits(aap_output, aap_actual)

Expand Down

0 comments on commit ced93d5

Please sign in to comment.