-
Notifications
You must be signed in to change notification settings - Fork 5
/
main.py
57 lines (45 loc) · 1.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import numpy as np
import argparse
import torch
from Params import Params
from dataset import Dataset
from Logger import Logger
from evaluation import Evaluator
from Trainer import Trainer
from ModelBuilder import build_model
# ---------------------------------------------------------------------------
# Entry script: parse CLI options, load the per-model configuration, build
# the dataset / evaluator / model, and run training.
# ---------------------------------------------------------------------------

# Command-line options: model name, data / save / config directories, RNG seed.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='slim')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--save_dir', type=str, default='./saves')
parser.add_argument('--conf_dir', type=str, default='./conf')
parser.add_argument('--seed', type=int, default=428)
conf = parser.parse_args()

# The per-model configuration is built from <conf_dir>/<lower-cased model>.json.
model_conf = Params(os.path.join(conf.conf_dir, conf.model.lower() + '.json'))
# Pass the parsed CLI options to the config object under the name 'exp_conf'.
model_conf.update_dict('exp_conf', conf.__dict__)

# Seed NumPy and PyTorch with the same value.
np.random.seed(conf.seed)
torch.random.manual_seed(conf.seed)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dataset = Dataset(
    data_dir=conf.data_dir,
    data_name=model_conf.data_name,
    train_ratio=model_conf.train_ratio,
    device=device
)

# Logs go under <save_dir>/<model> so that the --save_dir option is honored.
# With the default './saves' this resolves to the same directory as the
# literal 'saves' prefix.
# NOTE(review): the config file name uses conf.model.lower() while this path
# uses conf.model as typed — confirm whether the casing difference is intended.
log_dir = os.path.join(conf.save_dir, conf.model)
logger = Logger(log_dir)
# The config is written into logger.log_dir (not log_dir); presumably Logger
# may resolve its own run directory — TODO confirm against Logger.
model_conf.save(os.path.join(logger.log_dir, 'config.json'))

# Evaluation inputs come from the dataset; item popularity and top_k are
# forwarded to the evaluator as-is.
eval_pos, eval_target = dataset.eval_data()
item_popularity = dataset.item_popularity
evaluator = Evaluator(eval_pos, eval_target, item_popularity, model_conf.top_k)

model = build_model(conf.model, model_conf, dataset.num_users, dataset.num_items, device)

# Record the configuration and the dataset in the log before training starts.
logger.info(model_conf)
logger.info(dataset)

trainer = Trainer(
    dataset=dataset,
    model=model,
    evaluator=evaluator,
    logger=logger,
    conf=model_conf
)

trainer.train()