Skip to content

Commit

Permalink
complete mlp
Browse files: browse the repository at this point in the history
  • Loading branch information
[email protected] committed Jul 4, 2024
1 parent 9fd2c87 commit 88bd055
Showing 1 changed file with 29 additions and 24 deletions.
53 changes: 29 additions & 24 deletions core/model_finetuning/mlp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -114,75 +114,80 @@ def project_main():
cfg.merge_from_list(args.opts)

cfg.data.name = args.data

cfg.data.device = args.device
cfg.decoder.device = args.device
cfg.device = args.device
cfg.train.epochs = args.epoch
cfg.embedder.type = args.embedder_type
evaluator_hit = Evaluator(name='ogbl-collab')
evaluator_mrr = Evaluator(name='ogbl-citation2')

cfg.out_dir = 'results/tfidf'
custom_set_out_dir(cfg, args.cfg_file, cfg.wandb.name_tag)
# torch.set_num_threads(20)
loggers = create_logger(args.repeat)
for run_id, seed, split_index in zip(*run_loop_settings(cfg, args)):
print(f'run id : {run_id}')
# Set configurations for each run TODO clean code here
train_dataset = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_train_dataset.pt')
train_labels = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_train_labels.pt')
val_dataset = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_val_dataset.pt')
val_labels = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_val_labels.pt')
test_dataset = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_test_dataset.pt')
test_labels = torch.load(f'./generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_test_labels.pt')
root = '/hkfs/work/workspace_haic/scratch/cc7738-TAGBench/TAPE/core/model_finetuning'
train_dataset = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_train_dataset.pt')
train_labels = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_train_labels.pt')
val_dataset = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_val_dataset.pt')
val_labels = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_val_labels.pt')
test_dataset = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_test_dataset.pt')
test_labels = torch.load(f'{root}/generated_dataset/{cfg.data.name}/{cfg.embedder.type}_{seed}_test_labels.pt')

clf = RidgeClassifier(tol=1e-2, solver="sparse_cg")
clf.fit(train_dataset, train_labels)

test_pred = clf.predict(test_dataset)
test_acc = sum(np.asarray(test_labels) == test_pred ) / len(test_labels)
test_acc = sum(np.asarray(test_labels) == test_pred ) / len(test_labels)
y_pos_pred, y_neg_pred = torch.tensor(test_pred[test_labels == 1]), torch.tensor(test_pred[test_labels == 0])
test_metrics = get_metric_score(evaluator_hit, evaluator_mrr, y_pos_pred, y_neg_pred)
test_metrics.update({'ACC': round(test_acc, 4)})

train_pred = clf.predict(train_dataset)
train_acc = sum(np.asarray(train_labels) == train_pred ) / len(train_labels)
train_acc = sum(np.asarray(train_labels) == train_pred ) / len(train_labels)
y_pos_pred, y_neg_pred = torch.tensor(train_pred[train_labels == 1]), torch.tensor(train_pred[train_labels == 0])
train_metrics = get_metric_score(evaluator_hit, evaluator_mrr, y_pos_pred, y_neg_pred)
train_metrics.update({'ACC': round(train_acc, 4)})

val_pred = clf.predict(val_dataset)
val_acc = sum(np.asarray(val_labels) == val_pred ) / len(val_labels)
val_acc = sum(np.asarray(val_labels) == val_pred ) / len(val_labels)
y_pos_pred, y_neg_pred = torch.tensor(val_pred[val_labels == 1]), torch.tensor(val_pred[val_labels == 0])
val_metrics = get_metric_score(evaluator_hit, evaluator_mrr, y_pos_pred, y_neg_pred)
val_metrics.update({'ACC': round(val_acc, 4)})

print(f'Accuracy: {test_acc:.4f}')
print(f'metrics : {test_metrics}')

results_rank = {
key: (test_metrics[key], train_metrics[key], val_metrics[key])
for key in test_metrics.keys()
}

for key, result in results_rank.items():
loggers[key].add_result(run_id, result)
st()

for key in results_rank.keys():
print(loggers[key].calc_run_stats(run_id))
for key in results_rank:
print(key, loggers[key].results)

for key in results_rank.keys():
print(loggers[key].calc_all_stats())


root = os.path.join(FILE_PATH, cfg.out_dir)
acc_file = os.path.join(root, f'{cfg.data.name}_wb_acc_mrr.csv')
acc_file = os.path.join(root, f'{cfg.data.name}_lm_mrr.csv')

results_dict = {key: loggers[key].calc_all_stats() for key in results_rank.keys()}
run_result = {}
for key in loggers.keys():
print(key)
_, _, _, test_bvalid, _, _ = loggers[key].calc_all_stats(True)
run_result[key] = test_bvalid

os.makedirs(root, exist_ok=True)
name_tag = cfg.wandb.name_tag = f'{cfg.data.name}_run{id}_{args.model}'
mvari_str2csv(name_tag, results_dict, acc_file)
name_tag = cfg.wandb.name_tag = f'{cfg.data.name}_run{run_id}_{args.embedder_type}'
mvari_str2csv(name_tag, run_result, acc_file)
# clf = MLPClassifier(random_state=1, max_iter=300).fit(train_dataset, train_labels)
# test_proba = clf.predict_proba(test_dataset)
# test_pred = clf.predict(test_dataset)
Expand Down

0 comments on commit 88bd055

Please sign in to comment.