diff --git a/core/data_utils/load.py b/core/data_utils/load.py
index 16095297b2..f8b573706b 100644
--- a/core/data_utils/load.py
+++ b/core/data_utils/load.py
@@ -8,16 +8,22 @@
                                     load_tag_product,
                                     load_tag_ogbn_arxiv,
                                     load_tag_product,
-                                    load_tag_arxiv23)
+                                    load_tag_arxiv23,
+                                    load_tag_citeseer,
+                                    load_tag_citationv8)
 from data_utils.load_data_lp import (load_taglp_arxiv2023,
                                      load_taglp_cora,
                                      load_taglp_pubmed,
                                      load_taglp_product,
-                                     load_taglp_ogbn_arxiv)
+                                     load_taglp_ogbn_arxiv,
+                                     load_taglp_citeseer,
+                                     load_taglp_citationv8)
 from data_utils.load_data_lp import (load_graph_cora,
                                      load_graph_arxiv23,
                                      load_graph_ogbn_arxiv,
-                                     load_graph_pubmed)
+                                     load_graph_pubmed,
+                                     load_graph_citeseer,
+                                     load_graph_citationv8)

 # TODO standarize the input and output
 load_data_nc = {
@@ -26,6 +32,8 @@
     'arxiv_2023': load_tag_arxiv23,
     'ogbn-arxiv': load_tag_ogbn_arxiv,
     'ogbn-products': load_tag_product,
+    'citeseer': load_tag_citeseer,
+    'citationv8': load_tag_citationv8,
 }

 load_data_lp = {
@@ -34,6 +42,8 @@
     'arxiv_2023': load_taglp_arxiv2023,
     'ogbn-arxiv': load_taglp_ogbn_arxiv,
     'ogbn-products': load_taglp_product,
+    'citeseer': load_taglp_citeseer,
+    'citationv8': load_taglp_citationv8,
 }

 load_graph_lp = {
@@ -41,6 +51,8 @@
     'pubmed': load_graph_pubmed,
     'arxiv_2023': load_graph_arxiv23,
     'ogbn-arxiv': load_graph_ogbn_arxiv,
+    'citeseer': load_graph_citeseer,
+    'citationv8': load_graph_citationv8,
 }

diff --git a/core/data_utils/load_data_lp.py b/core/data_utils/load_data_lp.py
index 57b5c13286..0f6c6376d0 100644
--- a/core/data_utils/load_data_lp.py
+++ b/core/data_utils/load_data_lp.py
@@ -21,7 +21,8 @@
     load_tag_arxiv23, load_graph_cora, load_graph_pubmed, \
     load_graph_arxiv23, load_graph_ogbn_arxiv, load_text_cora, \
     load_text_pubmed, load_text_arxiv23, load_text_ogbn_arxiv, \
-    load_text_product
+    load_text_product, load_text_citeseer, load_text_citationv8, \
+    load_graph_citeseer, load_graph_citationv8

 from graphgps.utility.utils import get_git_repo_root_path, config_device, init_cfg_test
 from graphgps.utility.utils import time_logger
@@ -146,21 +147,53 @@ def load_taglp_pubmed(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
     )
     return splits, text, data

+def load_taglp_citeseer(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
+    # TODO: add one default argument
+
+    data = load_graph_citeseer()
+    text = load_text_citeseer()
+    undirected = data.is_undirected()
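+    # get_edge_split expects True when the underlying graph is undirected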
+
+    splits = get_edge_split(data,
+                            undirected,
+                            cfg.device,
+                            cfg.split_index[1],
+                            cfg.split_index[2],
+                            cfg.include_negatives,
+                            cfg.split_labels
+                            )
+    return splits, text, data
+
+def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
+    # TODO: add one default argument
+
+    data = load_graph_citationv8()
+    text = load_text_citationv8()
+    undirected = data.is_undirected()
+
+    splits = get_edge_split(data,
+                            undirected,
+                            cfg.device,
+                            cfg.split_index[1],
+                            cfg.split_index[2],
+                            cfg.include_negatives,
+                            cfg.split_labels
+                            )
+    return splits, text, data
+
 # TEST CODE
 if __name__ == '__main__':
     args = init_cfg_test()
     print(args)
-    data, text, __ = load_taglp_arxiv2023(args.data)
+    '''data, text, __ = load_taglp_arxiv2023(args.data)
     print(data)
     print(type(text))

     data, text = load_taglp_cora(args.data)
     print(data)
     print(type(text))

-    data, text = load_taglp_ogbn_arxiv(args.data)
-    print(data)
-    print(type(text))
+
     data, text = load_taglp_product(args.data)
     print(data)
@@ -168,4 +201,16 @@ def load_taglp_pubmed(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
     data, text = load_taglp_pubmed(args.data)
     print(data)
+    print(type(text))'''
+
+    splits, text, data = load_taglp_citeseer(args.data)
+    print(data)
+    print(type(text))
+
+    splits, text, data = load_taglp_citationv8(args.data)
+    print(data)
+    print(type(text))
+
+    splits, text, data = load_taglp_ogbn_arxiv(args.data)
+    print(data)
     print(type(text))
\ No newline at end of file
diff --git a/core/data_utils/load_data_nc.py b/core/data_utils/load_data_nc.py
index 6bddfb9beb..39a1bdd75c 100644
--- a/core/data_utils/load_data_nc.py
+++ b/core/data_utils/load_data_nc.py
@@ -1,6 +1,6 @@
 import os, sys
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-
+import dgl
 import torch
 import pandas as pd
 import numpy as np
@@ -419,7 +419,7 @@ def load_graph_ogbn_arxiv(use_mask):

 def load_tag_ogbn_arxiv() -> List[str]:
-    graph = load_graph_ogbn_arxiv()
+    graph = load_graph_ogbn_arxiv(False)
     text = load_text_ogbn_arxiv()
     return graph, text

@@ -436,15 +436,69 @@ def load_tag_product() -> Tuple[Data, List[str]]:
     return data, text

+def load_graph_citationv8() -> Data:
+    graph = dgl.load_graphs(FILE_PATH + 'core/dataset/citationv8/Citation-2015.pt')[0][0]
+    graph = dgl.to_bidirected(graph)
+    from torch_geometric.utils import from_dgl
+    graph = from_dgl(graph)
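+    # to_bidirected makes the edges undirected; from_dgl converts the DGL graph
+    # into a torch_geometric.data.Data object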
space: {hyperparameter_search}") - for hidden_channels, num_layers, batch_size, lr in tqdm(itertools.product(*hyperparameter_search.values())): + for hidden_channels, batch_size, lr in tqdm(itertools.product(*hyperparameter_search.values())): cfg.model.hidden_channels = hidden_channels cfg.train.batch_size = batch_size cfg.optimizer.lr = lr - cfg.model.num_layers = num_layers print_logger.info( - f"hidden_channels: {hidden_channels}, num_layers: {num_layers}, batch_size: {batch_size}, lr: {lr}") + f"hidden_channels: {hidden_channels}, batch_size: {batch_size}, lr: {lr}") start_time = time.time() model = NBFNet(cfg.model.in_channels, [cfg.model.hidden_channels] * cfg.model.num_layers, num_relation = 1) @@ -159,7 +158,7 @@ def parse_args() -> argparse.Namespace: run_result[key] = test_bvalid run_result.update( - {'hidden_channels': hidden_channels,' num_layers': num_layers, 'batch_size': batch_size,'lr': lr, + {'hidden_channels': hidden_channels, 'batch_size': batch_size,'lr': lr, }) print_logger.info(run_result) diff --git a/core/gcns/neognn_main.py b/core/gcns/neognn_main.py index 7fc0dc6738..babe47972b 100644 --- a/core/gcns/neognn_main.py +++ b/core/gcns/neognn_main.py @@ -3,6 +3,8 @@ from torch_sparse import SparseTensor +from torch_geometric.graphgym import params_count + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import argparse import time @@ -22,29 +24,39 @@ from graphgps.network.neognn import NeoGNN, LinkPredictor from data_utils.load import load_data_lp from graphgps.train.neognn_train import Trainer_NeoGNN +from graphgps.network.ncn import predictor_dict + def parse_args() -> argparse.Namespace: r"""Parses the command line arguments.""" parser = argparse.ArgumentParser(description='GraphGym') - parser.add_argument('--cfg', dest='cfg_file', type=str, required=False, - default='core/yamls/cora/gcns/seal.yaml', + default='core/yamls/cora/gcns/ncn.yaml', help='The configuration file path.') + parser.add_argument('--sweep', dest='sweep_file', type=str, required=False, - default='core/yamls/cora/gcns/gae_sp1.yaml', + default='core/yamls/cora/gcns/ncn.yaml', help='The configuration file path.') - parser.add_argument('--data', dest='data', type=str, required=False, - default='cora', + parser.add_argument('--data', dest='data', type=str, required=True, + default='pubmed', help='data name') - parser.add_argument('--repeat', type=int, default=2, help='The number of repeated jobs.') + parser.add_argument('--batch_size', dest='bs', type=int, required=False, + default=2**15, + help='data name') + parser.add_argument('--device', dest='device', required=True, + help='device id') + parser.add_argument('--epochs', dest='epoch', type=int, required=True, + default=400, + help='data name') + parser.add_argument('--wandb', dest='wandb', required=False, + help='data name') parser.add_argument('--mark_done', action='store_true', help='Mark yaml as done after a job has finished.') parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, help='See graphgym/config.py for remaining options.') - return parser.parse_args() def ngnn_dataset(data, splits): @@ -69,18 +81,27 @@ def ngnn_dataset(data, splits): cfg = set_cfg(FILE_PATH, args.cfg_file) cfg.merge_from_list(args.opts) + cfg.data.name = args.data + + cfg.data.device = args.device + cfg.model.device = args.device + cfg.device = args.device + cfg.train.epochs = args.epoch + torch.set_num_threads(cfg.num_threads) batch_sizes = [cfg.train.batch_size] best_acc = 0 best_params = {} loggers = 
+    cfg.data.name = args.data
+
+    cfg.data.device = args.device
+    cfg.model.device = args.device
+    cfg.device = args.device
+    cfg.train.epochs = args.epoch
+
     torch.set_num_threads(cfg.num_threads)
     batch_sizes = [cfg.train.batch_size]

     best_acc = 0
     best_params = {}
     loggers = create_logger(args.repeat)
     for batch_size in batch_sizes:
         for run_id, seed, split_index in zip(
                 *run_loop_settings(cfg, args)):
             custom_set_run_dir(cfg, run_id)
-            set_printing(cfg)
+            print_logger = set_printing(cfg)
             cfg.seed = seed
             cfg.run_id = run_id
             seed_everything(cfg.seed)
@@ -113,7 +134,7 @@ def ngnn_dataset(data, splits):
                                 run_id,
                                 args.repeat,
                                 loggers,
-                                print_logger=None,
+                                print_logger=print_logger,
                                 batch_size=batch_size)

             start = time.time()
@@ -130,3 +151,8 @@ def ngnn_dataset(data, splits):
                 result_dict[key] = valid_test

             trainer.save_result(result_dict)
+
+            cfg.model.params = params_count(model)
+            print_logger.info(f'Num parameters: {cfg.model.params}')
+            trainer.finalize()
+            print_logger.info(f"Inference time: {trainer.run_result['eval_time']}")

diff --git a/core/gcns/subgraph_sketching_tune.py b/core/gcns/subgraph_sketching_tune.py
index 43e11ca05c..362fa772c2 100644
--- a/core/gcns/subgraph_sketching_tune.py
+++ b/core/gcns/subgraph_sketching_tune.py
@@ -159,23 +159,21 @@ def parse_args() -> argparse.Namespace:
                       f"\n Valid: {2 * splits['train']['pos_edge_label'].shape[0]} samples,"
                       f"\n Test: {2 * splits['test']['pos_edge_label'].shape[0]} samples")
     dump_cfg(cfg)
-    hyperparameter_search = {'hidden_channels': [128, 256, 512, 1024], 'num_layers': [2, 3],
-                             "batch_size": [512, 1024], "lr": [0.01, 0.001, 0.0001],
+    hyperparameter_search = {'hidden_channels': [128, 256, 512, 1024], "batch_size": [512, 1024], "lr": [0.01, 0.001, 0.0001],
                              'max_hash_hops': [2, 3], 'label_dropout': [0.1, 0.3, 0.5],
                              'feature_dropout': [0.1, 0.3, 0.5], 'sign_dropout': [0.1, 0.3, 0.5],}
     print_logger.info(f"hypersearch space: {hyperparameter_search}")
-    for hidden_channels, num_layers, batch_size, lr, max_hash_hops, label_dropout, feature_dropout, sign_dropout in tqdm(
+    for hidden_channels, batch_size, lr, max_hash_hops, label_dropout, feature_dropout, sign_dropout in tqdm(
            itertools.product(*hyperparameter_search.values())):
         cfg.model.hidden_channels = hidden_channels
         cfg.train.batch_size = batch_size
         cfg.optimizer.lr = lr
-        cfg.model.num_layers = num_layers
         cfg.model.max_hash_hops = max_hash_hops
         cfg.model.label_dropout = label_dropout
         cfg.model.feature_dropout = feature_dropout
         cfg.model.sign_dropout = sign_dropout
         print_logger.info(
-            f"hidden_channels: {hidden_channels}, num_layers: {num_layers}, batch_size: {batch_size}, lr: {lr}")
+            f"hidden_channels: {hidden_channels}, batch_size: {batch_size}, lr: {lr}")
         start_time = time.time()
         if cfg.model.type == 'BUDDY':
             splits['train'] = hash_dataset(splits['train'])
@@ -225,7 +223,7 @@ def parse_args() -> argparse.Namespace:
             run_result[key] = test_bvalid

         run_result.update(
-            {'hidden_channels': hidden_channels,' num_layers': num_layers, 'batch_size': batch_size,'lr': lr,
+            {'hidden_channels': hidden_channels, 'batch_size': batch_size, 'lr': lr,
              'max_hash_hops': max_hash_hops, 'label_dropout': label_dropout,
              'feature_dropout': feature_dropout, 'sign_dropout': sign_dropout
              })

diff --git a/core/graphgps/network/nbfnet.py b/core/graphgps/network/nbfnet.py
index c2ae72b519..eb3f942438 100644
--- a/core/graphgps/network/nbfnet.py
+++ b/core/graphgps/network/nbfnet.py
@@ -48,6 +48,31 @@ def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult"
             mlp.append(nn.Linear(feature_dim, 1))
         self.mlp = nn.Sequential(*mlp)

+    def remove_easy_edges(self, data, h_index, t_index, r_index=None):
+        # we remove training edges (we need to predict them at training time) from the edge index
+        # think of it as a dynamic edge dropout
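+        # each (h, t) pair is mirrored, and the relation index is offset by
+        # num_relation // 2 so the inverse edges are matched as well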
+        h_index_ext = torch.cat([h_index, t_index], dim=-1)
+        t_index_ext = torch.cat([t_index, h_index], dim=-1)
+        r_index_ext = torch.cat([r_index, r_index + self.num_relation // 2], dim=-1)
+        if self.remove_one_hop:
+            # we remove all existing immediate edges between heads and tails in the batch
+            edge_index = data.edge_index
+            easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
+            index = edge_match(edge_index, easy_edge)[0]
+            mask = ~index_to_mask(index, data.num_edges)
+        else:
+            # we remove existing immediate edges between heads and tails in the batch with the given relation
+            edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
+            # note that here we add relation types r_index_ext to the matching query
+            easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)
+            index = edge_match(edge_index, easy_edge)[0]
+            mask = ~index_to_mask(index, data.num_edges)
+
+        data = copy.copy(data)
+        data.edge_index = data.edge_index[:, mask]
+        data.edge_type = data.edge_type[mask]
+        return data
+
     def negative_sample_to_tail(self, h_index, t_index, r_index):
         # convert p(h | t, r) to p(t' | h', r')
         # h' = t, r' = r^{-1}, t' = h
@@ -102,6 +127,12 @@ def bellmanford(self, data, h_index, r_index, separate_grad=False):

     def forward(self, data, batch):
         h_index, t_index, r_index = batch.unbind(-1)
+        if self.training:
+            # Edge dropout in the training mode
+            # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
+            # to make NBFNet iteration learn non-trivial paths
+            data = self.remove_easy_edges(data, h_index, t_index, r_index)
+
         shape = h_index.shape
         # turn all triples in a batch into a tail prediction mode
         h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index)

diff --git a/core/graphgps/train/neognn_train.py b/core/graphgps/train/neognn_train.py
index 84939fd40f..1c7eef7034 100644
--- a/core/graphgps/train/neognn_train.py
+++ b/core/graphgps/train/neognn_train.py
@@ -27,6 +27,9 @@
 from typing import Dict, Tuple
 from graphgps.train.opt_train import (Trainer)
 from graphgps.utility.ncn import PermIterator
+from torch.utils.tensorboard import SummaryWriter
+
+writer = SummaryWriter()


 class Trainer_NeoGNN(Trainer):
@@ -79,6 +85,20 @@ def __init__(self,
         self.name_tag = cfg.wandb.name_tag
         self.run_result = {}

+        self.tensorboard_writer = writer
+        self.out_dir = cfg.out_dir
+        self.run_dir = cfg.run_dir
+
+        report_step = {
+            'cora': 1,
+            'pubmed': 1,
+            'arxiv_2023': 1,
+            'ogbn-arxiv': 1,
+            'ogbn-products': 1,
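+            # assumed report intervals for the datasets added in this PR;
+            # without them the report_step lookup below raises a KeyError
+            'citeseer': 1,
+            'citationv8': 1,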
+        }
+
+        self.report_step = report_step[cfg.data.name]
+
     def _train_neognn(self):
         self.model.train()
         total_loss = 0
@@ -113,8 +133,15 @@ def _train_neognn(self):
             pos_loss = -torch.log(pos_out_feat_large + 1e-15).mean()
             neg_loss = -torch.log(1 - neg_out_feat_large + 1e-15).mean()
             loss2 = pos_loss + neg_loss
-            pos_loss = -torch.log(pos_out + 1e-15).mean()
-            neg_loss = -torch.log(1 - neg_out + 1e-15).mean()
+            eps = 1e-15  # Small constant to avoid log(0)
+
+            # Clamp the outputs to avoid log(0) or log(1 - 1) issues
+            pos_out = torch.clamp(pos_out, min=eps, max=1 - eps)
+            neg_out = torch.clamp(neg_out, min=eps, max=1 - eps)
+
+            # Calculate the losses
+            pos_loss = -torch.log(pos_out).mean()
+            neg_loss = -torch.log(1 - neg_out).mean()
             loss3 = pos_loss + neg_loss
             loss = loss1 + loss2 + loss3
             loss.backward()
@@ -127,7 +154,6 @@ def _train_neognn(self):
             total_loss += loss.item() * num_examples
             total_examples += num_examples
             count += 1
-
         return total_loss / total_examples

     def train(self):
@@ -137,75 +163,79 @@ def train(self):
             if torch.isnan(torch.tensor(loss)):
                 print('Loss is nan')
                 break
-            if epoch % 100 == 0:
-                results_rank = self.merge_result_rank()
-                print(results_rank)
-
-                for key, result in results_rank.items():
-                    print(key, result)
+            if epoch % int(self.report_step) == 0:
+                self.results_rank = self.merge_result_rank()
+
+                self.print_logger.info(
+                    f'Epoch: {epoch:03d}, Loss_train: {loss:.4f}, AUC: {self.results_rank["AUC"][0]:.4f}, AP: {self.results_rank["AP"][0]:.4f}, MRR: {self.results_rank["MRR"][0]:.4f}, Hit@10 {self.results_rank["Hits@10"][0]:.4f}')
+                self.print_logger.info(
+                    f'Epoch: {epoch:03d}, Loss_valid: {loss:.4f}, AUC: {self.results_rank["AUC"][1]:.4f}, AP: {self.results_rank["AP"][1]:.4f}, MRR: {self.results_rank["MRR"][1]:.4f}, Hit@10 {self.results_rank["Hits@10"][1]:.4f}')
+                self.print_logger.info(
+                    f'Epoch: {epoch:03d}, Loss_test: {loss:.4f}, AUC: {self.results_rank["AUC"][2]:.4f}, AP: {self.results_rank["AP"][2]:.4f}, MRR: {self.results_rank["MRR"][2]:.4f}, Hit@10 {self.results_rank["Hits@10"][2]:.4f}')
+
+                for key, result in self.results_rank.items():
                     self.loggers[key].add_result(self.run, result)
-                    print(self.run)
-                    print(result)
+                    self.tensorboard_writer.add_scalar(f"Metrics/Train/{key}", result[0], epoch)
+                    self.tensorboard_writer.add_scalar(f"Metrics/Valid/{key}", result[1], epoch)
+                    self.tensorboard_writer.add_scalar(f"Metrics/Test/{key}", result[2], epoch)
+
+                    train_hits, valid_hits, test_hits = result
+                    self.print_logger.info(
+                        f'Run: {self.run + 1:02d}, Key: {key}, '
+                        f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {100 * train_hits:.2f}, Valid: {100 * valid_hits:.2f}, Test: {100 * test_hits:.2f}%')
+
+                self.print_logger.info('---')

         return best_auc, best_hits

-    @torch.no_grad()
-    def _test(self, data: Data):
-        self.model.eval()
-        self.predictor.eval()
-        pos_edge = data['pos_edge_label_index'].to(self.device)
-        neg_edge = data['neg_edge_label_index'].to(self.device)
-        pos_pred,_,_,_ = self.model(pos_edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
-        pos_pred = pos_pred.squeeze()
-        neg_pred,_,_,_ = self.model(neg_edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
-        neg_pred = neg_pred.squeeze()
-
-        y_pred = torch.cat([pos_pred, neg_pred], dim=0)
-        hard_thres = (y_pred.max() + y_pred.min()) / 2
-        pos_y = torch.ones(pos_edge.size(1))
-        neg_y = torch.zeros(neg_edge.size(1))
-        y_true = torch.cat([pos_y, neg_y], dim=0)
-        '''self.save_pred(y_pred, y_true, data)'''
-
-        y_pred = torch.where(y_pred >= hard_thres, torch.tensor(1), torch.tensor(0))
-
-        y_true = y_true.clone().detach()
-        y_pred = y_pred.clone().detach()
-
-        y_pred, y_true = y_pred.detach().cpu().numpy(), y_true.detach().cpu().numpy()
-        fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
-        return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred), auc(fpr, tpr)

     @torch.no_grad()
     def _evaluate(self, eval_data: Data):
-
         self.model.eval()
         self.predictor.eval()
+
         pos_edge = eval_data['pos_edge_label_index'].to(self.device)
         neg_edge = eval_data['neg_edge_label_index'].to(self.device)
-        pos_pred, _, _, _ = self.model(pos_edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
-        pos_pred = pos_pred.squeeze()
-        neg_pred, _, _, _ = self.model(neg_edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
-        neg_pred = neg_pred.squeeze()
+
+        pos_pred_list, neg_pred_list = [], []
+
+        with torch.no_grad():
+            for perm in DataLoader(range(pos_edge.size(1)), self.batch_size, shuffle=True):
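+                # score edges in mini-batches so large evaluation sets fit in memory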
+                # Positive edge prediction
+                edge = pos_edge[:, perm].to(self.device)
+                pos_pred, _, _, _ = self.model(edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
+                pos_pred = pos_pred.squeeze()
+                pos_pred_list.append(pos_pred.cpu())
+
+                # Negative edge prediction
+                edge = neg_edge[:, perm].to(self.device)
+                neg_pred, _, _, _ = self.model(edge, self.data, self.data.A, self.predictor, emb=self.data.emb.weight)
+                neg_pred = neg_pred.squeeze()
+                neg_pred_list.append(neg_pred.cpu())
+
+        # Concatenate predictions and create labels
+        pos_pred = torch.cat(pos_pred_list, dim=0)
+        neg_pred = torch.cat(neg_pred_list, dim=0)

         y_pred = torch.cat([pos_pred, neg_pred], dim=0)
-        hard_thres = (y_pred.max() + y_pred.min()) / 2
-        pos_y = torch.ones(pos_edge.size(1))
-        neg_y = torch.zeros(neg_edge.size(1))
-        y_true = torch.cat([pos_y, neg_y], dim=0)
-        '''self.save_pred(y_pred, y_true, eval_data)'''
-        pos_pred, neg_pred = y_pred[y_true == 1].cpu(), y_pred[y_true == 0].cpu()
-        y_pred = torch.where(y_pred >= hard_thres, torch.tensor(1), torch.tensor(0))
+        hard_thres = (y_pred.max() + y_pred.min()) / 2

-        y_true = y_true.clone().detach().cpu()
-        y_pred = y_pred.clone().detach().cpu()
+        # Create true labels
+        pos_y = torch.ones(pos_pred.size(0))
+        neg_y = torch.zeros(neg_pred.size(0))
+        y_true = torch.cat([pos_y, neg_y], dim=0)
+        # Convert predictions to binary labels
+        y_pred_binary = torch.where(y_pred >= hard_thres, torch.tensor(1), torch.tensor(0))
+        # Move to CPU for evaluation
+        y_true = y_true.cpu()
+        y_pred_binary = y_pred_binary.cpu()

-        acc = torch.sum(y_true == y_pred) / len(y_true)
+        acc = torch.sum(y_true == y_pred_binary).item() / len(y_true)

+        # Assuming get_metric_score is defined elsewhere and computes metrics
         result_mrr = get_metric_score(self.evaluator_hit, self.evaluator_mrr, pos_pred, neg_pred)
-        result_mrr.update({'ACC': round(acc.tolist(), 5)})
+        result_mrr.update({'ACC': round(acc, 5)})

         return result_mrr

@@ -224,5 +254,11 @@ def save_pred(self, pred, true, data):
             pred_value = pred[idx]
             true_value = true[idx]
             f.write(f"{corresponding_node_ids[0].item()} {corresponding_node_ids[1].item()} {pred_value} {true_value}\n")
+
+    def finalize(self):
+        import time
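+        # time one full evaluation pass over the test split (reported as inference time)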
+        start_eval = time.time()
+        self._evaluate(self.test_data)
+        self.run_result['eval_time'] = time.time() - start_eval

diff --git a/core/yamls/arxiv_2023/gcns/elph.yaml b/core/yamls/arxiv_2023/gcns/elph.yaml
index 09bde6b999..12b7f932ef 100644
--- a/core/yamls/arxiv_2023/gcns/elph.yaml
+++ b/core/yamls/arxiv_2023/gcns/elph.yaml
@@ -45,11 +45,11 @@ num_threads: 11

 wandb:
   use: True
   project: gtblueprint
-  name_tag: elph-pubmed
+  name_tag: elph-arxiv_2023

 data:
-  name: pubmed
+  name: arxiv_2023
   undirected: True
   include_negatives: True
   val_pct: 0.15

diff --git a/core/yamls/arxiv_2023/gcns/ncnc.yaml b/core/yamls/arxiv_2023/gcns/ncnc.yaml
index 9f8c50516a..a8d00a5cfe 100644
--- a/core/yamls/arxiv_2023/gcns/ncnc.yaml
+++ b/core/yamls/arxiv_2023/gcns/ncnc.yaml
@@ -42,9 +42,9 @@ model:
   preedp: 0.0 #
-  probscale: 5.0 #
-  proboffset: 3.0 #
-  pt: 0.05 #
+  probscale: 2.0 #
+  proboffset: 2.0 #
+  pt: 0.1 #

 num_threads: 11

diff --git a/core/yamls/arxiv_2023/gcns/neognn.yaml b/core/yamls/arxiv_2023/gcns/neognn.yaml
index 62fc3681c3..9276c9ab1d 100644
--- a/core/yamls/arxiv_2023/gcns/neognn.yaml
+++ b/core/yamls/arxiv_2023/gcns/neognn.yaml
@@ -12,8 +12,8 @@ run:

 train:
   mode: custom
-  batch_size: 256 #
-  gnn_batch_size: 1024 #
+  batch_size: 1024 #
+  gnn_batch_size: 8192 #
   eval_period: 1
   epochs: 500
   device: 0
@@ -26,10 +26,10 @@ model:
   out_channels: 32
   in_channels: 1433
   hidden_channels: 256 #
-  num_layers: 3 #
-  dropout: 0.5 #
+  num_layers: 2 #
+  dropout: 0.3 #
   f_edge_dim: 8 #
-  f_node_dim: 128 #
+  f_node_dim: 64 #
   g_phi_dim: 128 #

 num_threads: 11
@@ -53,5 +53,5 @@ data:

 optimizer:
   type: adam
-  lr: 0.0001 #
+  lr: 0.001 #
   weight_decay: 0.0005
\ No newline at end of file

diff --git a/core/yamls/cora/gcns/neognn.yaml b/core/yamls/cora/gcns/neognn.yaml
index b338aa5c3a..b67761b9a8 100644
--- a/core/yamls/cora/gcns/neognn.yaml
+++ b/core/yamls/cora/gcns/neognn.yaml
@@ -12,8 +12,8 @@ run:

 train:
   mode: custom
-  batch_size: 256 #
-  gnn_batch_size: 1024 #
+  batch_size: 1024 #
+  gnn_batch_size: 4096 #
   eval_period: 1
   epochs: 500
   device: 0
@@ -25,11 +25,11 @@ model:
   type: NeoGNN
   out_channels: 32
   in_channels: 1433
-  hidden_channels: 256 #
-  num_layers: 3 #
+  hidden_channels: 64 #
+  num_layers: 2 #
   dropout: 0.5 #
   f_edge_dim: 8 #
-  f_node_dim: 128 #
+  f_node_dim: 64 #
   g_phi_dim: 128 #

 num_threads: 11
@@ -53,5 +53,5 @@ data:

 optimizer:
   type: adam
-  lr: 0.0001 #
+  lr: 0.01 #
   weight_decay: 0.0005
\ No newline at end of file

diff --git a/core/yamls/pubmed/gcns/neognn.yaml b/core/yamls/pubmed/gcns/neognn.yaml
index e7e2790a92..870f4e71fb 100644
--- a/core/yamls/pubmed/gcns/neognn.yaml
+++ b/core/yamls/pubmed/gcns/neognn.yaml
@@ -12,8 +12,8 @@ run:

 train:
   mode: custom
-  batch_size: 256 #
-  gnn_batch_size: 1024 #
+  batch_size: 512 #
+  gnn_batch_size: 4096 #
   eval_period: 1
   epochs: 500
   device: 0
@@ -25,11 +25,11 @@ model:
   type: NeoGNN
   out_channels: 32
   in_channels: 1433
-  hidden_channels: 256 #
-  num_layers: 3 #
-  dropout: 0.5 #
+  hidden_channels: 128 #
+  num_layers: 2 #
+  dropout: 0.1 #
   f_edge_dim: 8 #
-  f_node_dim: 128 #
+  f_node_dim: 64 #
   g_phi_dim: 128 #

 num_threads: 11
@@ -53,5 +53,5 @@ data:

 optimizer:
   type: adam
-  lr: 0.0001 #
+  lr: 0.001 #
   weight_decay: 0.0005
\ No newline at end of file