From ce309ed154ae157b770aec6014b986f17474bc97 Mon Sep 17 00:00:00 2001 From: Judy Date: Sat, 11 May 2024 16:38:17 +0000 Subject: [PATCH 1/5] feat: implement data_pipeline and dataset for mf-bpr #12 --- configs/train_config.yaml | 18 +++--- data/datasets/mf_data_pipeline.py | 94 +++++++++++++++++++++++++++++++ data/datasets/mf_dataset.py | 49 ++++++++++++++++ poetry.lock | 66 ++++++++++++++-------- pyproject.toml | 2 +- train.py | 8 +++ 6 files changed, 205 insertions(+), 32 deletions(-) create mode 100644 data/datasets/mf_data_pipeline.py create mode 100644 data/datasets/mf_dataset.py diff --git a/configs/train_config.yaml b/configs/train_config.yaml index 1828656..621f144 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -17,15 +17,17 @@ epochs: 5 batch_size: 32 lr: 0.001 optimizer: adamw -loss_name: bce +loss_name: bpr # pointwise # bce patience: 5 top_n: 10 # model config -model_name: CDAE -negative_sampling: True # False -neg_times: 5 # this works only when negative_sampling == True, if value is 5, the number of negative samples will be 5 times the number of positive samples by users -hidden_size: 64 -corruption_level: 0.6 -hidden_activation: sigmoid -output_activation: sigmoid +#model_name: CDAE +#negative_sampling: True # False +#neg_times: 5 # this works only when negative_sampling == True, if value is 5, the number of negative samples will be 5 times the number of positive samples by users +#hidden_size: 64 +#corruption_level: 0.6 +#hidden_activation: sigmoid +#output_activation: sigmoid + +model_name: MF diff --git a/data/datasets/mf_data_pipeline.py b/data/datasets/mf_data_pipeline.py new file mode 100644 index 0000000..3932f5a --- /dev/null +++ b/data/datasets/mf_data_pipeline.py @@ -0,0 +1,94 @@ +import os + +import numpy as np +import pandas as pd +from tqdm import tqdm +from sklearn.model_selection import train_test_split +from loguru import logger + +from .data_pipeline import DataPipeline + +class MFDataPipeline(DataPipeline): + + def __init__(self, cfg): + super().__init__(cfg) + self.num_items = None + self.num_users = None + + def split(self, df): + ''' + train_data: ((user_id, item_id, rating), ...) + ''' + logger.info(f'start random user split...') + train_df, valid_df, test_df = [], [], [] + + for _, user_df in df.groupby('user_id'): + if self.cfg.loss_name == 'pointwise': + user_train_df, user_test_df = train_test_split(user_df, test_size=.2, stratify=user_df['rating']) + user_train_df, user_valid_df = train_test_split(user_train_df, test_size=.25, stratify=user_train_df['rating']) + else: + user_train_df, user_test_df = train_test_split(user_df, test_size=.2) + user_train_df, user_valid_df = train_test_split(user_train_df, test_size=.25) + train_df.append(user_train_df) + valid_df.append(user_valid_df) + test_df.append(user_test_df) + + train_df = pd.concat(train_df) + valid_df = pd.concat(valid_df) + test_df = pd.concat(test_df) + + train_pos_df = train_df.groupby('user_id')['business_id'].agg(list) + train_valid_pos_df = pd.concat([train_df, valid_df], axis=0).groupby('user_id')['business_id'].agg(list) + test_pos_df = test_df.groupby('user_id')['business_id'].agg(list) + + train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left') + valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left') + test_data = pd.merge(test_df, test_pos_df, left_on='user_id', right_on='user_id', how='left') + + return train_data, valid_data, test_data + + def preprocess(self) -> pd.DataFrame: + ''' + output: pivot table (row: user, col: user-specific vector + item set, values: binary preference) + ''' + logger.info("start preprocessing...") + # load df + df = self._load_df() + # set num items and num users + self._set_num_items_and_num_users(df) + # negative sampling + if self.cfg.loss_name == 'pointwise': + df = self._negative_sampling(df, self.cfg.neg_times) + logger.info("done") + return df + + def _load_df(self): + logger.info("load df...") + return pd.read_csv(os.path.join(self.cfg.data_dir, 'yelp_interactions.tsv'), sep='\t', index_col=False) + + def _set_num_items_and_num_users(self, df): + self.num_items = df.business_id.nunique() + self.num_users = df.user_id.nunique() + + def _negative_sampling(self, df: pd.DataFrame, neg_times: 5) -> pd.DataFrame: + logger.info(f"negative sampling...") + logger.info(f"before neg sampling: {df.shape}") + all_items = df.business_id.unique() + + df['rating'] = 1 + neg_data = [] + for _, user_df in df.groupby('user_id'): + user_id = user_df.user_id.values[0] + pos_items = user_df.business_id.unique() + neg_items = [] + while len(neg_items) < len(pos_items)*neg_times: + neg_item = np.random.choice(all_items) + if (neg_item in pos_items) or (neg_item in neg_items): continue + neg_items.append(neg_item) + neg_data.extend([[user_id, neg_item, 0] for neg_item in neg_items]) + + df = pd.concat([df, pd.DataFrame(neg_data, columns=df.columns)], axis=0) + df = df.sample(frac=1).reset_index(drop=True) + logger.info(f"after neg sampling: {df.shape}") + logger.info(f"done...") + return df diff --git a/data/datasets/mf_dataset.py b/data/datasets/mf_dataset.py new file mode 100644 index 0000000..8f153e4 --- /dev/null +++ b/data/datasets/mf_dataset.py @@ -0,0 +1,49 @@ +import numpy as np + +import torch +from torch.utils.data import Dataset + +from loguru import logger + +class MFDataset(Dataset): + + def __init__(self, data, mode='train', num_items=None): + super().__init__() + self.data = data + self.mode = mode + self.num_items = num_items + + def __len__(self): + return len(self.data.keys()) + + def _negative_sampling(self, input_item, user_positives): + neg_item = np.random.randint(self.num_items) + while neg_item in user_positives: + neg_item = np.random.randint(self.num_items) + return neg_item + + def __getitem__(self, user_id): + input_mask = self.data[user_id]['input_mask'].astype('float32') + if self.mode == 'train': + pos_item = self.data[user_id]['business_id'].astype('float32') + user_pos_items = self.data[user_id]['pos_items'] + return { + 'user_id': user_id, + 'pos_item': input_item, + 'neg_item': self._negative_sampling(input_item, user_positives) + } + elif self.mode == 'valid': + pos_item = self.data[user_id]['business_id'].astype('float32') + user_pos_items = self.data[user_id]['pos_items'] + return { + 'user_id': user_id, + 'pos_item': input_item, + 'neg_item': self._negative_sampling(input_item, user_positives) + } + else: + user_pos_items = self.data[user_id]['pos_items'].astype('float32') + return { + 'user_id': user_id, + 'pos_items': pos_items, + } + diff --git a/poetry.lock b/poetry.lock index 75a01af..f5ae38d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1595,32 +1595,52 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "scikit-learn" -version = "1.4.2" +version = "1.4.0" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.9" files = [ - {file = "scikit-learn-1.4.2.tar.gz", hash = "sha256:daa1c471d95bad080c6e44b4946c9390a4842adc3082572c20e4f8884e39e959"}, - {file = "scikit_learn-1.4.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8539a41b3d6d1af82eb629f9c57f37428ff1481c1e34dddb3b9d7af8ede67ac5"}, - {file = "scikit_learn-1.4.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:68b8404841f944a4a1459b07198fa2edd41a82f189b44f3e1d55c104dbc2e40c"}, - {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81bf5d8bbe87643103334032dd82f7419bc8c8d02a763643a6b9a5c7288c5054"}, - {file = "scikit_learn-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f0ea5d0f693cb247a073d21a4123bdf4172e470e6d163c12b74cbb1536cf38"}, - {file = "scikit_learn-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:87440e2e188c87db80ea4023440923dccbd56fbc2d557b18ced00fef79da0727"}, - {file = "scikit_learn-1.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:45dee87ac5309bb82e3ea633955030df9bbcb8d2cdb30383c6cd483691c546cc"}, - {file = "scikit_learn-1.4.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1d0b25d9c651fd050555aadd57431b53d4cf664e749069da77f3d52c5ad14b3b"}, - {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0203c368058ab92efc6168a1507d388d41469c873e96ec220ca8e74079bf62e"}, - {file = "scikit_learn-1.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44c62f2b124848a28fd695db5bc4da019287abf390bfce602ddc8aa1ec186aae"}, - {file = "scikit_learn-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:5cd7b524115499b18b63f0c96f4224eb885564937a0b3477531b2b63ce331904"}, - {file = "scikit_learn-1.4.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:90378e1747949f90c8f385898fff35d73193dfcaec3dd75d6b542f90c4e89755"}, - {file = "scikit_learn-1.4.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ff4effe5a1d4e8fed260a83a163f7dbf4f6087b54528d8880bab1d1377bd78be"}, - {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:671e2f0c3f2c15409dae4f282a3a619601fa824d2c820e5b608d9d775f91780c"}, - {file = "scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36d0bc983336bbc1be22f9b686b50c964f593c8a9a913a792442af9bf4f5e68"}, - {file = "scikit_learn-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:d762070980c17ba3e9a4a1e043ba0518ce4c55152032f1af0ca6f39b376b5928"}, - {file = "scikit_learn-1.4.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9993d5e78a8148b1d0fdf5b15ed92452af5581734129998c26f481c46586d68"}, - {file = "scikit_learn-1.4.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:426d258fddac674fdf33f3cb2d54d26f49406e2599dbf9a32b4d1696091d4256"}, - {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5460a1a5b043ae5ae4596b3126a4ec33ccba1b51e7ca2c5d36dac2169f62ab1d"}, - {file = "scikit_learn-1.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49d64ef6cb8c093d883e5a36c4766548d974898d378e395ba41a806d0e824db8"}, - {file = "scikit_learn-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:c97a50b05c194be9146d61fe87dbf8eac62b203d9e87a3ccc6ae9aed2dfaf361"}, + {file = "scikit-learn-1.4.0.tar.gz", hash = "sha256:d4373c984eba20e393216edd51a3e3eede56cbe93d4247516d205643c3b93121"}, + {file = "scikit_learn-1.4.0-1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fce93a7473e2f4ee4cc280210968288d6a7d7ad8dc6fa7bb7892145e407085f9"}, + {file = "scikit_learn-1.4.0-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d77df3d1e15fc37a9329999979fa7868ba8655dbab21fe97fc7ddabac9e08cc7"}, + {file = "scikit_learn-1.4.0-1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2404659fedec40eeafa310cd14d613e564d13dbf8f3c752d31c095195ec05de6"}, + {file = "scikit_learn-1.4.0-1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e98632da8f6410e6fb6bf66937712c949b4010600ccd3f22a5388a83e610cc3c"}, + {file = "scikit_learn-1.4.0-1-cp310-cp310-win_amd64.whl", hash = "sha256:11b3b140f70fbc9f6a08884631ae8dd60a4bb2d7d6d1de92738ea42b740d8992"}, + {file = "scikit_learn-1.4.0-1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a8341eabdc754d5ab91641a7763243845e96b6d68e03e472531e88a4f1b09f21"}, + {file = "scikit_learn-1.4.0-1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1f6bce875ac2bb6b52514f67c185c564ccd299a05b65b7bab091a4c13dde12d"}, + {file = "scikit_learn-1.4.0-1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c408b46b2fd61952d519ea1af2f8f0a7a703e1433923ab1704c4131520b2083b"}, + {file = "scikit_learn-1.4.0-1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b465dd1dcd237b7b1dcd1a9048ccbf70a98c659474324fa708464c3a2533fad"}, + {file = "scikit_learn-1.4.0-1-cp311-cp311-win_amd64.whl", hash = "sha256:0db8e22c42f7980fe5eb22069b1f84c48966f3e0d23a01afde5999e3987a2501"}, + {file = "scikit_learn-1.4.0-1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7eef6ea2ed289af40e88c0be9f7704ca8b5de18508a06897c3fe21e0905efdf"}, + {file = "scikit_learn-1.4.0-1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:349669b01435bc4dbf25c6410b0892073befdaec52637d1a1d1ff53865dc8db3"}, + {file = "scikit_learn-1.4.0-1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d439c584e58434d0350701bd33f6c10b309e851fccaf41c121aed55f6851d8cf"}, + {file = "scikit_learn-1.4.0-1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0e2427d9ef46477625ab9b55c1882844fe6fc500f418c3f8e650200182457bc"}, + {file = "scikit_learn-1.4.0-1-cp312-cp312-win_amd64.whl", hash = "sha256:d3d75343940e7bf9b85c830c93d34039fa015eeb341c5c0b4cd7a90dadfe00d4"}, + {file = "scikit_learn-1.4.0-1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:76986d22e884ab062b1beecdd92379656e9d3789ecc1f9870923c178de55f9fe"}, + {file = "scikit_learn-1.4.0-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:e22446ad89f1cb7657f0d849dcdc345b48e2d10afa3daf2925fdb740f85b714c"}, + {file = "scikit_learn-1.4.0-1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74812c9eabb265be69d738a8ea8d4884917a59637fcbf88a5f0e9020498bc6b3"}, + {file = "scikit_learn-1.4.0-1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aad2a63e0dd386b92da3270887a29b308af4d7c750d8c4995dfd9a4798691bcc"}, + {file = "scikit_learn-1.4.0-1-cp39-cp39-win_amd64.whl", hash = "sha256:53b9e29177897c37e2ff9d4ba6ca12fdb156e22523e463db05def303f5c72b5c"}, + {file = "scikit_learn-1.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cb8f044a8f5962613ce1feb4351d66f8d784bd072d36393582f351859b065f7d"}, + {file = "scikit_learn-1.4.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:a6372c90bbf302387792108379f1ec77719c1618d88496d0df30cb8e370b4661"}, + {file = "scikit_learn-1.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:785ce3c352bf697adfda357c3922c94517a9376002971bc5ea50896144bc8916"}, + {file = "scikit_learn-1.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0aba2a20d89936d6e72d95d05e3bf1db55bca5c5920926ad7b92c34f5e7d3bbe"}, + {file = "scikit_learn-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:2bac5d56b992f8f06816f2cd321eb86071c6f6d44bb4b1cb3d626525820d754b"}, + {file = "scikit_learn-1.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27ae4b0f1b2c77107c096a7e05b33458354107b47775428d1f11b23e30a73e8a"}, + {file = "scikit_learn-1.4.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5c5c62ffb52c3ffb755eb21fa74cc2cbf2c521bd53f5c04eaa10011dbecf5f80"}, + {file = "scikit_learn-1.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f0d2018ac6fa055dab65fe8a485967990d33c672d55bc254c56c35287b02fab"}, + {file = "scikit_learn-1.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a8918c415c4b4bf1d60c38d32958849a9191c2428ab35d30b78354085c7c7a"}, + {file = "scikit_learn-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:80a21de63275f8bcd7877b3e781679d2ff1eddfed515a599f95b2502a3283d42"}, + {file = "scikit_learn-1.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0f33bbafb310c26b81c4d41ecaebdbc1f63498a3f13461d50ed9a2e8f24d28e4"}, + {file = "scikit_learn-1.4.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:8b6ac1442ec714b4911e5aef8afd82c691b5c88b525ea58299d455acc4e8dcec"}, + {file = "scikit_learn-1.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05fc5915b716c6cc60a438c250108e9a9445b522975ed37e416d5ea4f9a63381"}, + {file = "scikit_learn-1.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:842b7d6989f3c574685e18da6f91223eb32301d0f93903dd399894250835a6f7"}, + {file = "scikit_learn-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:88bcb586fdff865372df1bc6be88bb7e6f9e0aa080dab9f54f5cac7eca8e2b6b"}, + {file = "scikit_learn-1.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f77674647dd31f56cb12ed13ed25b6ed43a056fffef051715022d2ebffd7a7d1"}, + {file = "scikit_learn-1.4.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:833999872e2920ce00f3a50839946bdac7539454e200eb6db54898a41f4bfd43"}, + {file = "scikit_learn-1.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:970ec697accaef10fb4f51763f3a7b1250f9f0553cf05514d0e94905322a0172"}, + {file = "scikit_learn-1.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:923d778f378ebacca2c672ab1740e5a413e437fb45ab45ab02578f8b689e5d43"}, + {file = "scikit_learn-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:1d041bc95006b545b59e458399e3175ab11ca7a03dc9a74a573ac891f5df1489"}, ] [package.dependencies] @@ -2018,4 +2038,4 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "5f85d32c2712b2a7bbb8eaa0a08e0883cfaedde4f7f9818313138a6a58f44246" +content-hash = "8a955b71c6ad9e9bc943034d78cdf5556da853cd6723d71176d48212764ffab0" diff --git a/pyproject.toml b/pyproject.toml index b36f045..e2b312d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ matplotlib = "^3.8.4" seaborn = "^0.13.2" hydra-core = "^1.3.2" loguru = "^0.7.2" -scikit-learn = "^1.4.2" +scikit-learn = "1.4.0" [tool.poetry.group.dev.dependencies] diff --git a/train.py b/train.py index 122c818..eb86144 100644 --- a/train.py +++ b/train.py @@ -2,7 +2,9 @@ from omegaconf import OmegaConf from data.datasets.cdae_data_pipeline import CDAEDataPipeline +from data.datasets.mf_data_pipeline import MFDataPipeline from data.datasets.cdae_dataset import CDAEDataset +from data.datasets.mf_dataset import MFDataset from trainers.cdae_trainer import CDAETrainer from utils import set_seed @@ -18,6 +20,8 @@ def main(cfg: OmegaConf): if cfg.model_name in ('CDAE', ): data_pipeline = CDAEDataPipeline(cfg) + elif cfg.model_name == 'MF': + data_pipeline = MFDataPipeline(cfg) else: raise ValueError() @@ -28,6 +32,10 @@ def main(cfg: OmegaConf): train_dataset = CDAEDataset(train_data, 'train', neg_times=cfg.neg_times) valid_dataset = CDAEDataset(valid_data, 'valid', neg_times=cfg.neg_times) test_dataset = CDAEDataset(test_data, 'test') + elif cfg.model_name == 'MF': + train_dataset = MFDataset(train_data, 'train', num_items=data_pipeline.num_items) + valid_dataset = MFDataset(valid_data, 'valid', num_items=data_pipeline.num_items) + test_dataset = MFDataset(test_data, 'test') else: raise ValueError() From 80f9109bf310364457dd92a3d7fa7012d2a6b180 Mon Sep 17 00:00:00 2001 From: Judy Date: Sat, 11 May 2024 18:08:35 +0000 Subject: [PATCH 2/5] feat: implement mf model and bpr loss and train method in Trainer #12 --- configs/train_config.yaml | 1 + data/datasets/mf_data_pipeline.py | 14 ++++++++------ data/datasets/mf_dataset.py | 29 +++++++++++++++-------------- loss.py | 11 +++++++++++ models/mf.py | 19 +++++++++++++++++++ train.py | 5 +++++ 6 files changed, 59 insertions(+), 20 deletions(-) create mode 100644 models/mf.py diff --git a/configs/train_config.yaml b/configs/train_config.yaml index 621f144..ce6dd5c 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -31,3 +31,4 @@ top_n: 10 #output_activation: sigmoid model_name: MF +embed_size: 64 diff --git a/data/datasets/mf_data_pipeline.py b/data/datasets/mf_data_pipeline.py index 3932f5a..b21955b 100644 --- a/data/datasets/mf_data_pipeline.py +++ b/data/datasets/mf_data_pipeline.py @@ -33,13 +33,15 @@ def split(self, df): valid_df.append(user_valid_df) test_df.append(user_test_df) - train_df = pd.concat(train_df) - valid_df = pd.concat(valid_df) - test_df = pd.concat(test_df) + train_df = pd.concat(train_df).reset_index() + valid_df = pd.concat(valid_df).reset_index() + test_df = pd.concat(test_df).reset_index() - train_pos_df = train_df.groupby('user_id')['business_id'].agg(list) - train_valid_pos_df = pd.concat([train_df, valid_df], axis=0).groupby('user_id')['business_id'].agg(list) - test_pos_df = test_df.groupby('user_id')['business_id'].agg(list) + train_pos_df = train_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) + train_valid_pos_df = pd.concat([train_df, valid_df], axis=0).groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) + test_pos_df = test_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) + + logger.info(train_pos_df) train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left') valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left') diff --git a/data/datasets/mf_dataset.py b/data/datasets/mf_dataset.py index 8f153e4..4fee983 100644 --- a/data/datasets/mf_dataset.py +++ b/data/datasets/mf_dataset.py @@ -22,28 +22,29 @@ def _negative_sampling(self, input_item, user_positives): neg_item = np.random.randint(self.num_items) return neg_item - def __getitem__(self, user_id): - input_mask = self.data[user_id]['input_mask'].astype('float32') + def __getitem__(self, index): + data = self.data.iloc[index,:] + logger.info(data) if self.mode == 'train': - pos_item = self.data[user_id]['business_id'].astype('float32') - user_pos_items = self.data[user_id]['pos_items'] + pos_item = data['business_id'].astype('int64') + user_pos_items = data['pos_items'] return { - 'user_id': user_id, - 'pos_item': input_item, - 'neg_item': self._negative_sampling(input_item, user_positives) + 'user_id': data['user_id'].astype('int64'), + 'pos_item': pos_item, + 'neg_item': self._negative_sampling(pos_item, user_pos_items) } elif self.mode == 'valid': - pos_item = self.data[user_id]['business_id'].astype('float32') - user_pos_items = self.data[user_id]['pos_items'] + pos_item = data['business_id'].astype('int64') + user_pos_items = data['pos_items'] return { - 'user_id': user_id, - 'pos_item': input_item, - 'neg_item': self._negative_sampling(input_item, user_positives) + 'user_id': data['user_id'].astype('int64'), + 'pos_item': pos_item, + 'neg_item': self._negative_sampling(pos_item, user_pos_items) } else: - user_pos_items = self.data[user_id]['pos_items'].astype('float32') + user_pos_items = data['pos_items'].astype('int64') return { - 'user_id': user_id, + 'user_id': self.data[index]['user_id'].astype('int64'), 'pos_items': pos_items, } diff --git a/loss.py b/loss.py index 7c1cdc6..dd897c7 100644 --- a/loss.py +++ b/loss.py @@ -14,3 +14,14 @@ def forward(self, input: Tensor, target: Tensor, negative_mask: Tensor) -> Tenso loss_targets = (target.add(negative_mask)).nonzero(as_tuple=True) # compute loss only for nonzero indices return nn.functional.binary_cross_entropy(input[loss_targets], target[loss_targets], weight=self.weight, reduction=self.reduction) + + +class BPRLoss(nn.Module): + + def __init__(self): + super().__init__() + self.sigmoid = nn.Sigmoid() + + def forward(self, positive_preds, negative_preds): + difference = positive_preds - negative_preds + return -torch.log(self.sigmoid(difference)).mean() diff --git a/models/mf.py b/models/mf.py new file mode 100644 index 0000000..33fe4f2 --- /dev/null +++ b/models/mf.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn + +from models.base_model import BaseModel + +class MatrixFactorization(BaseModel): + + def __init__(self, cfg, num_users, num_items): + super().__init__() + self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32) + self.item_embedding = nn.Embedding(num_items, cfg.embed_size, dtype=torch.float32) + + def _init_weights(self): + for child in self.children(): + if isinstance(child, nn.Embedding): + nn.init.normal_(child.weights) + + def forward(self, user_id, item_id): + return torch.matmul(self.user_embedding(user_id), self.item_embedding(item_id).T) diff --git a/train.py b/train.py index eb86144..429df5c 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,7 @@ from data.datasets.cdae_dataset import CDAEDataset from data.datasets.mf_dataset import MFDataset from trainers.cdae_trainer import CDAETrainer +from trainers.mf_trainer import MFTrainer from utils import set_seed import torch @@ -27,6 +28,7 @@ def main(cfg: OmegaConf): df = data_pipeline.preprocess() train_data, valid_data, test_data = data_pipeline.split(df) + logger.info(train_data) if cfg.model_name in ('CDAE', ): train_dataset = CDAEDataset(train_data, 'train', neg_times=cfg.neg_times) @@ -49,6 +51,9 @@ def main(cfg: OmegaConf): trainer.run(train_dataloader, valid_dataloader) trainer.load_best_model() trainer.evaluate(test_dataloader) + elif cfg.model_name in ('MF', ): + trainer = MFTrainer(cfg, data_pipeline.num_items, data_pipeline.num_users) + trainer.run(train_dataloader, valid_dataloader) if __name__ == '__main__': main() From 2ad1e11a3e955ffe6f778bfedd210d6037d975f0 Mon Sep 17 00:00:00 2001 From: Judy Date: Mon, 13 May 2024 15:45:12 +0000 Subject: [PATCH 3/5] fix: update validation dataset #12 --- configs/train_config.yaml | 3 ++- data/datasets/mf_data_pipeline.py | 10 +++++----- data/datasets/mf_dataset.py | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/train_config.yaml b/configs/train_config.yaml index ce6dd5c..8c2d1c5 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -13,13 +13,14 @@ notes: "..." # train config device: cuda # cpu -epochs: 5 +epochs: 10 batch_size: 32 lr: 0.001 optimizer: adamw loss_name: bpr # pointwise # bce patience: 5 top_n: 10 +weight_decay: 1e-5 # model config #model_name: CDAE diff --git a/data/datasets/mf_data_pipeline.py b/data/datasets/mf_data_pipeline.py index b21955b..dc7325d 100644 --- a/data/datasets/mf_data_pipeline.py +++ b/data/datasets/mf_data_pipeline.py @@ -38,16 +38,16 @@ def split(self, df): test_df = pd.concat(test_df).reset_index() train_pos_df = train_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) + valid_pos_df = valid_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) train_valid_pos_df = pd.concat([train_df, valid_df], axis=0).groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) test_pos_df = test_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) - logger.info(train_pos_df) - train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left') - valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left') - test_data = pd.merge(test_df, test_pos_df, left_on='user_id', right_on='user_id', how='left') + valid_data = pd.merge(valid_df, valid_pos_df, left_on='user_id', right_on='user_id', how='left') + valid_eval_data = pd.merge(valid_pos_df, train_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') + test_eval_data = pd.merge(test_pos_df, train_valid_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') - return train_data, valid_data, test_data + return train_data, valid_data, valid_eval_data, test_eval_data def preprocess(self) -> pd.DataFrame: ''' diff --git a/data/datasets/mf_dataset.py b/data/datasets/mf_dataset.py index 4fee983..6c81ed4 100644 --- a/data/datasets/mf_dataset.py +++ b/data/datasets/mf_dataset.py @@ -24,7 +24,6 @@ def _negative_sampling(self, input_item, user_positives): def __getitem__(self, index): data = self.data.iloc[index,:] - logger.info(data) if self.mode == 'train': pos_item = data['business_id'].astype('int64') user_pos_items = data['pos_items'] From 73f8ac663a1335248359ff02929f35b92215df8f Mon Sep 17 00:00:00 2001 From: Judy Date: Mon, 13 May 2024 15:46:58 +0000 Subject: [PATCH 4/5] feat: implement validate and evaluate methods in Trainer and add weight decay config #12 --- loss.py | 2 +- models/mf.py | 8 +- train.py | 11 +-- trainers/base_trainer.py | 6 +- trainers/mf_trainer.py | 156 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 172 insertions(+), 11 deletions(-) create mode 100644 trainers/mf_trainer.py diff --git a/loss.py b/loss.py index dd897c7..ca19e43 100644 --- a/loss.py +++ b/loss.py @@ -24,4 +24,4 @@ def __init__(self): def forward(self, positive_preds, negative_preds): difference = positive_preds - negative_preds - return -torch.log(self.sigmoid(difference)).mean() + return torch.mean(-torch.log(self.sigmoid(difference))) diff --git a/models/mf.py b/models/mf.py index 33fe4f2..512e4de 100644 --- a/models/mf.py +++ b/models/mf.py @@ -3,17 +3,21 @@ from models.base_model import BaseModel +from loguru import logger class MatrixFactorization(BaseModel): def __init__(self, cfg, num_users, num_items): super().__init__() self.user_embedding = nn.Embedding(num_users, cfg.embed_size, dtype=torch.float32) self.item_embedding = nn.Embedding(num_items, cfg.embed_size, dtype=torch.float32) + self._init_weights() def _init_weights(self): for child in self.children(): if isinstance(child, nn.Embedding): - nn.init.normal_(child.weights) + nn.init.normal_(child.weight) def forward(self, user_id, item_id): - return torch.matmul(self.user_embedding(user_id), self.item_embedding(item_id).T) + user_emb = self.user_embedding(user_id) + item_emb = self.item_embedding(item_id) + return torch.sum(user_emb * item_emb, dim=1) diff --git a/train.py b/train.py index 429df5c..74a0623 100644 --- a/train.py +++ b/train.py @@ -27,24 +27,24 @@ def main(cfg: OmegaConf): raise ValueError() df = data_pipeline.preprocess() - train_data, valid_data, test_data = data_pipeline.split(df) - logger.info(train_data) if cfg.model_name in ('CDAE', ): + train_data, valid_data, test_data = data_pipeline.split(df) train_dataset = CDAEDataset(train_data, 'train', neg_times=cfg.neg_times) valid_dataset = CDAEDataset(valid_data, 'valid', neg_times=cfg.neg_times) test_dataset = CDAEDataset(test_data, 'test') elif cfg.model_name == 'MF': + train_data, valid_data, valid_eval_data, test_eval_data = data_pipeline.split(df) train_dataset = MFDataset(train_data, 'train', num_items=data_pipeline.num_items) valid_dataset = MFDataset(valid_data, 'valid', num_items=data_pipeline.num_items) - test_dataset = MFDataset(test_data, 'test') else: raise ValueError() # pos_samples 를 이용한 negative sample을 수행해줘야 함 train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) - test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size) + if cfg.model_name != 'MF': + test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size) if cfg.model_name in ('CDAE', ): trainer = CDAETrainer(cfg, len(df.columns)-1, len(train_dataset)) @@ -53,7 +53,8 @@ def main(cfg: OmegaConf): trainer.evaluate(test_dataloader) elif cfg.model_name in ('MF', ): trainer = MFTrainer(cfg, data_pipeline.num_items, data_pipeline.num_users) - trainer.run(train_dataloader, valid_dataloader) + trainer.run(train_dataloader, valid_dataloader, valid_eval_data) + trainer.evaluate(test_eval_data, 'test') if __name__ == '__main__': main() diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index e683ecd..0177706 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -32,11 +32,11 @@ def _model(self, model_name: str) -> Module: logger.error(f"Not implemented model: {model_name}") raise NotImplementedError(f"Not implemented model: {model_name}") - def _optimizer(self, optimizer_name: str, model: Module, learning_rate: float) -> Optimizer: + def _optimizer(self, optimizer_name: str, model: Module, learning_rate: float, weight_decay: float=0) -> Optimizer: if optimizer_name.lower() == 'adam': - return Adam(model.parameters(), lr=learning_rate) + return Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) elif optimizer_name.lower() == 'adamw': - return AdamW(model.parameters(), lr=learning_rate) + return AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) else: logger.error(f"Optimizer Not Exists: {optimizer_name}") raise NotImplementedError(f"Optimizer Not Exists: {optimizer_name}") diff --git a/trainers/mf_trainer.py b/trainers/mf_trainer.py new file mode 100644 index 0000000..1caf5d1 --- /dev/null +++ b/trainers/mf_trainer.py @@ -0,0 +1,156 @@ +import numpy as np +import pandas as pd +from tqdm import tqdm + +import torch +import torch.nn as nn +from torch import Tensor +from torch.utils.data import DataLoader +from torch.optim import Optimizer + +from loguru import logger +from omegaconf.dictconfig import DictConfig + +from models.mf import MatrixFactorization +from .base_trainer import BaseTrainer +from metric import * +from loss import BPRLoss + +class MFTrainer(BaseTrainer): + def __init__(self, cfg: DictConfig, num_items: int, num_users: int) -> None: + super().__init__(cfg) + self.num_items = num_items + self.num_users = num_users + self.model = MatrixFactorization(self.cfg, num_users, num_items).to(self.device) + self.optimizer: Optimizer = self._optimizer(self.cfg.optimizer, self.model, self.cfg.lr, self.cfg.weight_decay) + self.loss = self._loss() + + def _loss(self): + return BPRLoss() + + def run(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, valid_eval_data: pd.DataFrame): + logger.info(f"[Trainer] run...") + + best_valid_loss: float = 1e+6 + best_valid_precision_at_k: float = .0 + best_valid_recall_at_k: float = .0 + best_valid_map_at_k: float = .0 + best_valid_ndcg_at_k: float = .0 + best_epoch: int = 0 + endurance: int = 0 + + # train + for epoch in range(self.cfg.epochs): + train_loss: float = self.train(train_dataloader) + valid_loss: float = self.validate(valid_dataloader) + (valid_precision_at_k, + valid_recall_at_k, + valid_map_at_k, + valid_ndcg_at_k) = self.evaluate(valid_eval_data, 'valid') + logger.info(f'''\n[Trainer] epoch: {epoch} > train loss: {train_loss:.4f} / + valid loss: {valid_loss:.4f} / + precision@K : {valid_precision_at_k:.4f} / + Recall@K: {valid_recall_at_k:.4f} / + MAP@K: {valid_map_at_k:.4f} / + NDCG@K: {valid_ndcg_at_k:.4f}''') + + # update model + if best_valid_loss > valid_loss: + logger.info(f"[Trainer] update best model...") + best_valid_loss = valid_loss + best_valid_precision_at_k = valid_precision_at_k + best_recall_k = valid_recall_at_k + best_valid_ndcg_at_k = valid_ndcg_at_k + best_valid_map_at_k = valid_map_at_k + best_epoch = epoch + endurance = 0 + + # TODO: add mlflow + + torch.save(self.model.state_dict(), f'{self.cfg.model_dir}/best_model.pt') + else: + endurance += 1 + if endurance > self.cfg.patience: + logger.info(f"[Trainer] ealry stopping...") + break + + + def train(self, train_dataloader: DataLoader) -> float: + self.model.train() + train_loss = 0 + for data in tqdm(train_dataloader): + user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \ + data['neg_item'].to(self.device) + pos_pred = self.model(user_id, pos_item) + neg_pred = self.model(user_id, neg_item) + + self.optimizer.zero_grad() + loss = self.loss(pos_pred, neg_pred) + loss.backward() + self.optimizer.step() + + train_loss += loss.item() + + return train_loss + + def validate(self, valid_dataloader: DataLoader) -> tuple[float]: + self.model.eval() + valid_loss = 0 + actual, predicted = [], [] + for data in tqdm(valid_dataloader): + user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \ + data['neg_item'].to(self.device) + pos_pred = self.model(user_id, pos_item) + neg_pred = self.model(user_id, neg_item) + + loss = self.loss(pos_pred, neg_pred) + + valid_loss += loss.item() + + return valid_loss + + def evaluate(self, eval_data: pd.DataFrame, mode='valid') -> tuple: + + self.model.eval() + actual, predicted = [], [] + item_input = torch.tensor([item_id for item_id in range(self.num_items)]).to(self.device) + for user_id, row in tqdm(eval_data.iterrows(), total=eval_data.shape[0]): + pred = self.model(torch.tensor([user_id,]*self.num_items).to(self.device), item_input) + batch_predicted = \ + self._generate_top_k_recommendation(pred, row['mask_items']) + actual.append(row['pos_items']) + predicted.append(batch_predicted) + + test_precision_at_k = precision_at_k(actual, predicted, self.cfg.top_n) + test_recall_at_k = recall_at_k(actual, predicted, self.cfg.top_n) + test_map_at_k = map_at_k(actual, predicted, self.cfg.top_n) + test_ndcg_at_k = ndcg_at_k(actual, predicted, self.cfg.top_n) + + if mode == 'test': + logger.info(f'''\n[Trainer] Test > + precision@{self.cfg.top_n} : {test_precision_at_k:.4f} / + Recall@{self.cfg.top_n}: {test_recall_at_k:.4f} / + MAP@{self.cfg.top_n}: {test_map_at_k:.4f} / + NDCG@{self.cfg.top_n}: {test_ndcg_at_k:.4f}''') + + return (test_precision_at_k, + test_recall_at_k, + test_map_at_k, + test_ndcg_at_k) + + def _generate_top_k_recommendation(self, pred: Tensor, mask_items) -> tuple[list]: + + # mask to train items + pred = pred.cpu().detach().numpy() + pred[mask_items] = 0 + + # find the largest topK item indexes by user + topn_index = np.argpartition(pred, -self.cfg.top_n)[ -self.cfg.top_n:] + # take probs from predictions using above indexes + topn_prob = np.take_along_axis(pred, topn_index, axis=0) + # sort topK probs and find their indexes + sorted_indices = np.argsort(-topn_prob) + # apply sorted indexes to item indexes to get sorted topK item indexes by user + topn_index_sorted = np.take_along_axis(topn_index, sorted_indices, axis=0) + + return topn_index_sorted From 1e93fd320523b68d17d5e34d312ef131eb05af6a Mon Sep 17 00:00:00 2001 From: Judy Date: Tue, 14 May 2024 12:50:43 +0000 Subject: [PATCH 5/5] refactor: Remove unnecessary lines and correct __len__ method in MFDataset #12 --- configs/train_config.yaml | 4 ++-- data/datasets/mf_data_pipeline.py | 4 ++-- data/datasets/mf_dataset.py | 37 +++++++++---------------------- loss.py | 4 ++-- models/mf.py | 2 +- train.py | 5 +++-- trainers/base_trainer.py | 4 +++- trainers/mf_trainer.py | 6 ++--- 8 files changed, 26 insertions(+), 40 deletions(-) diff --git a/configs/train_config.yaml b/configs/train_config.yaml index 8c2d1c5..e43d642 100644 --- a/configs/train_config.yaml +++ b/configs/train_config.yaml @@ -16,11 +16,11 @@ device: cuda # cpu epochs: 10 batch_size: 32 lr: 0.001 -optimizer: adamw +optimizer: sgd # adamw loss_name: bpr # pointwise # bce patience: 5 top_n: 10 -weight_decay: 1e-5 +weight_decay: 0 #1e-5 # model config #model_name: CDAE diff --git a/data/datasets/mf_data_pipeline.py b/data/datasets/mf_data_pipeline.py index dc7325d..fbd0580 100644 --- a/data/datasets/mf_data_pipeline.py +++ b/data/datasets/mf_data_pipeline.py @@ -17,7 +17,7 @@ def __init__(self, cfg): def split(self, df): ''' - train_data: ((user_id, item_id, rating), ...) + data: ((user_id, item_id, rating), ...) ''' logger.info(f'start random user split...') train_df, valid_df, test_df = [], [], [] @@ -43,7 +43,7 @@ def split(self, df): test_pos_df = test_df.groupby('user_id').agg({'business_id': [('pos_items', list)]}).droplevel(0, 1) train_data = pd.merge(train_df, train_pos_df, left_on='user_id', right_on='user_id', how='left') - valid_data = pd.merge(valid_df, valid_pos_df, left_on='user_id', right_on='user_id', how='left') + valid_data = pd.merge(valid_df, train_valid_pos_df, left_on='user_id', right_on='user_id', how='left') valid_eval_data = pd.merge(valid_pos_df, train_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') test_eval_data = pd.merge(test_pos_df, train_valid_pos_df.rename(columns={'pos_items': 'mask_items'}), left_on='user_id', right_on='user_id', how='left') diff --git a/data/datasets/mf_dataset.py b/data/datasets/mf_dataset.py index 6c81ed4..592ea5c 100644 --- a/data/datasets/mf_dataset.py +++ b/data/datasets/mf_dataset.py @@ -7,16 +7,15 @@ class MFDataset(Dataset): - def __init__(self, data, mode='train', num_items=None): + def __init__(self, data, num_items=None): super().__init__() self.data = data - self.mode = mode self.num_items = num_items def __len__(self): - return len(self.data.keys()) + return self.data.shape[0] - def _negative_sampling(self, input_item, user_positives): + def _negative_sampling(self, user_positives): neg_item = np.random.randint(self.num_items) while neg_item in user_positives: neg_item = np.random.randint(self.num_items) @@ -24,26 +23,10 @@ def _negative_sampling(self, input_item, user_positives): def __getitem__(self, index): data = self.data.iloc[index,:] - if self.mode == 'train': - pos_item = data['business_id'].astype('int64') - user_pos_items = data['pos_items'] - return { - 'user_id': data['user_id'].astype('int64'), - 'pos_item': pos_item, - 'neg_item': self._negative_sampling(pos_item, user_pos_items) - } - elif self.mode == 'valid': - pos_item = data['business_id'].astype('int64') - user_pos_items = data['pos_items'] - return { - 'user_id': data['user_id'].astype('int64'), - 'pos_item': pos_item, - 'neg_item': self._negative_sampling(pos_item, user_pos_items) - } - else: - user_pos_items = data['pos_items'].astype('int64') - return { - 'user_id': self.data[index]['user_id'].astype('int64'), - 'pos_items': pos_items, - } - + pos_item = data['business_id'].astype('int64') + user_pos_items = data['pos_items'] + return { + 'user_id': data['user_id'].astype('int64'), + 'pos_item': pos_item, + 'neg_item': self._negative_sampling(user_pos_items) + } diff --git a/loss.py b/loss.py index ca19e43..ba7ec0c 100644 --- a/loss.py +++ b/loss.py @@ -20,8 +20,8 @@ class BPRLoss(nn.Module): def __init__(self): super().__init__() - self.sigmoid = nn.Sigmoid() + self.logsigmoid = nn.LogSigmoid() def forward(self, positive_preds, negative_preds): difference = positive_preds - negative_preds - return torch.mean(-torch.log(self.sigmoid(difference))) + return torch.mean(-self.logsigmoid(difference)) diff --git a/models/mf.py b/models/mf.py index 512e4de..d9d945c 100644 --- a/models/mf.py +++ b/models/mf.py @@ -15,7 +15,7 @@ def __init__(self, cfg, num_users, num_items): def _init_weights(self): for child in self.children(): if isinstance(child, nn.Embedding): - nn.init.normal_(child.weight) + nn.init.xavier_uniform_(child.weight) def forward(self, user_id, item_id): user_emb = self.user_embedding(user_id) diff --git a/train.py b/train.py index 74a0623..1cd40a1 100644 --- a/train.py +++ b/train.py @@ -35,14 +35,15 @@ def main(cfg: OmegaConf): test_dataset = CDAEDataset(test_data, 'test') elif cfg.model_name == 'MF': train_data, valid_data, valid_eval_data, test_eval_data = data_pipeline.split(df) - train_dataset = MFDataset(train_data, 'train', num_items=data_pipeline.num_items) - valid_dataset = MFDataset(valid_data, 'valid', num_items=data_pipeline.num_items) + train_dataset = MFDataset(train_data, num_items=data_pipeline.num_items) + valid_dataset = MFDataset(valid_data, num_items=data_pipeline.num_items) else: raise ValueError() # pos_samples 를 이용한 negative sample을 수행해줘야 함 train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle) + if cfg.model_name != 'MF': test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index 0177706..ee5fe58 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -6,7 +6,7 @@ from torch import Tensor from torch.utils.data import DataLoader from torch.nn import Module, BCELoss -from torch.optim import Optimizer, Adam, AdamW +from torch.optim import Optimizer, Adam, AdamW, SGD from loguru import logger from omegaconf.dictconfig import DictConfig @@ -37,6 +37,8 @@ def _optimizer(self, optimizer_name: str, model: Module, learning_rate: float, w return Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) elif optimizer_name.lower() == 'adamw': return AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + elif optimizer_name.lower() == 'sgd': + return SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) else: logger.error(f"Optimizer Not Exists: {optimizer_name}") raise NotImplementedError(f"Optimizer Not Exists: {optimizer_name}") diff --git a/trainers/mf_trainer.py b/trainers/mf_trainer.py index 1caf5d1..3897caa 100644 --- a/trainers/mf_trainer.py +++ b/trainers/mf_trainer.py @@ -80,7 +80,7 @@ def train(self, train_dataloader: DataLoader) -> float: train_loss = 0 for data in tqdm(train_dataloader): user_id, pos_item, neg_item = data['user_id'].to(self.device), data['pos_item'].to(self.device), \ - data['neg_item'].to(self.device) + data['neg_item'].to(self.device) pos_pred = self.model(user_id, pos_item) neg_pred = self.model(user_id, neg_item) @@ -142,10 +142,10 @@ def _generate_top_k_recommendation(self, pred: Tensor, mask_items) -> tuple[list # mask to train items pred = pred.cpu().detach().numpy() - pred[mask_items] = 0 + pred[mask_items] = -3.40282e+38 # finfo(float32) # find the largest topK item indexes by user - topn_index = np.argpartition(pred, -self.cfg.top_n)[ -self.cfg.top_n:] + topn_index = np.argpartition(pred, -self.cfg.top_n)[-self.cfg.top_n:] # take probs from predictions using above indexes topn_prob = np.take_along_axis(pred, topn_index, axis=0) # sort topK probs and find their indexes