Merge branch 'main' of github.com:ChenS676/TAPE
ChenS676 committed Jul 8, 2024
2 parents 27aaddd + 3eec51b commit 6c6e166
Showing 13 changed files with 311 additions and 105 deletions.
18 changes: 15 additions & 3 deletions core/data_utils/load.py
@@ -8,16 +8,22 @@
                          load_tag_product,
                          load_tag_ogbn_arxiv,
                          load_tag_product,
-                         load_tag_arxiv23)
+                         load_tag_arxiv23,
+                         load_tag_citeseer,
+                         load_tag_citationv8)
 from data_utils.load_data_lp import (load_taglp_arxiv2023,
                                      load_taglp_cora,
                                      load_taglp_pubmed,
                                      load_taglp_product,
-                                     load_taglp_ogbn_arxiv)
+                                     load_taglp_ogbn_arxiv,
+                                     load_taglp_citeseer,
+                                     load_taglp_citationv8)
 from data_utils.load_data_lp import (load_graph_cora,
                                      load_graph_arxiv23,
                                      load_graph_ogbn_arxiv,
-                                     load_graph_pubmed)
+                                     load_graph_pubmed,
+                                     load_graph_citeseer,
+                                     load_graph_citationv8)

# TODO: standardize the input and output
load_data_nc = {
@@ -26,6 +32,8 @@
     'arxiv_2023': load_tag_arxiv23,
     'ogbn-arxiv': load_tag_ogbn_arxiv,
     'ogbn-products': load_tag_product,
+    'citeseer': load_tag_citeseer,
+    'citationv8': load_tag_citationv8,
 }

load_data_lp = {
@@ -34,13 +42,17 @@
     'arxiv_2023': load_taglp_arxiv2023,
     'ogbn-arxiv': load_taglp_ogbn_arxiv,
     'ogbn-products': load_taglp_product,
+    'citeseer': load_taglp_citeseer,
+    'citationv8': load_taglp_citationv8,
 }

 load_graph_lp = {
     'cora': load_graph_cora,
     'pubmed': load_graph_pubmed,
     'arxiv_2023': load_graph_arxiv23,
     'ogbn-arxiv': load_graph_ogbn_arxiv,
+    'citeseer': load_graph_citeseer,
+    'citationv8': load_graph_citationv8,
 }
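These registries map dataset names to loader callables, so downstream scripts dispatch by name instead of branching on strings. A minimal usage sketch, assuming the cfg layout used by the test code in core/data_utils/load_data_lp.py, where each loader is called with args.data:

from data_utils.load import load_data_lp

# Hypothetical dispatch: look up the link-prediction loader for CiteSeer.
# Loaders in load_data_lp return (splits, text, data); the config node is
# assumed to carry device, split_index, include_negatives and split_labels.
splits, text, data = load_data_lp['citeseer'](cfg.data)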


55 changes: 50 additions & 5 deletions core/data_utils/load_data_lp.py
@@ -21,7 +21,8 @@
     load_tag_arxiv23, load_graph_cora, load_graph_pubmed, \
     load_graph_arxiv23, load_graph_ogbn_arxiv, load_text_cora, \
     load_text_pubmed, load_text_arxiv23, load_text_ogbn_arxiv, \
-    load_text_product
+    load_text_product, load_text_citeseer, load_text_citationv8, \
+    load_graph_citeseer, load_graph_citationv8

from graphgps.utility.utils import get_git_repo_root_path, config_device, init_cfg_test
from graphgps.utility.utils import time_logger
@@ -146,26 +147,70 @@ def load_taglp_pubmed(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
                             )
     return splits, text, data
 
+def load_taglp_citeseer(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
+    # TODO: add a default argument for cfg
+    data = load_graph_citeseer()
+    text = load_text_citeseer()
+    undirected = data.is_undirected()
+
+    splits = get_edge_split(data,
+                            undirected,
+                            cfg.device,
+                            cfg.split_index[1],
+                            cfg.split_index[2],
+                            cfg.include_negatives,
+                            cfg.split_labels
+                            )
+    return splits, text, data
+
+def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
+    # TODO: add a default argument for cfg
+    data = load_graph_citationv8()
+    text = load_text_citationv8()
+    undirected = data.is_undirected()
+
+    splits = get_edge_split(data,
+                            undirected,
+                            cfg.device,
+                            cfg.split_index[1],
+                            cfg.split_index[2],
+                            cfg.include_negatives,
+                            cfg.split_labels
+                            )
+    return splits, text, data
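Both new loaders delegate to get_edge_split, whose result is consumed downstream keyed by split name. A hedged sketch of that contract — only the 'train'/'test' keys and 'pos_edge_label' are confirmed by core/gcns/nbfnet_tune.py below; any further schema is an assumption:

# Hypothetical inspection of an edge split produced by the loader above.
args = init_cfg_test()
splits, text, data = load_taglp_citeseer(args.data)
print(splits['train']['pos_edge_label'].shape[0], 'positive train edges')
print(splits['test']['pos_edge_label'].shape[0], 'positive test edges')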


# TEST CODE
 if __name__ == '__main__':
     args = init_cfg_test()
     print(args)
-    data, text, __ = load_taglp_arxiv2023(args.data)
+    '''data, text, __ = load_taglp_arxiv2023(args.data)
     print(data)
     print(type(text))
     data, text = load_taglp_cora(args.data)
     print(data)
     print(type(text))
     data, text = load_taglp_ogbn_arxiv(args.data)
     print(data)
     print(type(text))
     data, text = load_taglp_product(args.data)
     print(data)
     print(type(text))
     data, text = load_taglp_pubmed(args.data)
     print(data)
+    print(type(text))'''
+
+    splits, text, data = load_taglp_citeseer(args.data)
+    print(data)
+    print(type(text))
+
+    splits, text, data = load_taglp_citationv8(args.data)
+    print(data)
+    print(type(text))
+
+    splits, text, data = load_taglp_ogbn_arxiv(args.data)
+    print(data)
+    print(type(text))
65 changes: 62 additions & 3 deletions core/data_utils/load_data_nc.py
@@ -1,6 +1,6 @@
 import os, sys
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 
+import dgl
 import torch
 import pandas as pd
 import numpy as np
@@ -419,7 +419,7 @@ def load_graph_ogbn_arxiv(use_mask)


 def load_tag_ogbn_arxiv() -> List[str]:
-    graph = load_graph_ogbn_arxiv()
+    graph = load_graph_ogbn_arxiv(False)
     text = load_text_ogbn_arxiv()
     return graph, text

@@ -436,15 +436,69 @@ def load_tag_product() -> Tuple[Data, List[str]]:
     return data, text


+def load_graph_citationv8() -> Data:
+    graph = dgl.load_graphs(FILE_PATH + 'core/dataset/citationv8/Citation-2015.pt')[0][0]
+    graph = dgl.to_bidirected(graph)
+    from torch_geometric.utils import from_dgl
+    graph = from_dgl(graph)
+    return graph
+
+
+def load_text_citationv8() -> List[str]:
+    df = pd.read_csv(FILE_PATH + 'core/dataset/citationv8_orig/Citation-2015.csv')
+    return [
+        f'Text: {ti}\n'
+        for ti in df['text']
+    ]
+
+
+def load_tag_citationv8() -> Tuple[Data, List[str]]:
+    graph = load_graph_citationv8()
+    text = None
+    train_id, val_id, test_id, train_mask, val_mask, test_mask = get_node_mask(graph.num_nodes)
+    graph.train_id = train_id
+    graph.val_id = val_id
+    graph.test_id = test_id
+    graph.train_mask = train_mask
+    graph.val_mask = val_mask
+    graph.test_mask = test_mask
+    return graph, text
+
+
+def load_graph_citeseer() -> Data:
+    # load data
+    data_name = 'CiteSeer'
+    dataset = Planetoid('./generated_dataset', data_name, transform=T.NormalizeFeatures())
+    data = dataset[0]
+    return data
+
+
+def load_text_citeseer() -> List[str]:
+    return None
+
+
+def load_tag_citeseer() -> Tuple[Data, List[str]]:
+    graph = load_graph_citeseer()
+    text = load_text_citeseer()
+    return graph, text
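load_graph_citationv8 round-trips a DGL graph into PyG. A small sanity-check sketch of that conversion, assuming the same dataset path as the loader above (dgl.to_bidirected requires a simple graph, and from_dgl maps a homogeneous DGLGraph to a torch_geometric.data.Data):

import dgl
from torch_geometric.utils import from_dgl

g = dgl.load_graphs(FILE_PATH + 'core/dataset/citationv8/Citation-2015.pt')[0][0]
g = dgl.to_bidirected(g)     # symmetrize the edge set
data = from_dgl(g)           # convert to a PyG Data object
assert data.is_undirected()  # bidirected edges imply an undirected PyG graph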


 # Test code
 if __name__ == '__main__':
+    graph = load_graph_citeseer()
+    print(type(graph))
+    graph, text = load_tag_citeseer()
+    print(type(text))
 
     graph = load_graph_arxiv23()
     # print(type(graph))
     graph, text = load_tag_arxiv23()
     print(type(graph))
     print(type(text))

-    graph, _ = load_graph_cora(True)
+    '''graph, _ = load_graph_cora(True)
     # print(type(graph))
     graph, text = load_tag_cora()
     print(type(graph))
@@ -461,4 +515,9 @@ def load_tag_product() -> Tuple[Data, List[str]]:
     graph = load_graph_pubmed()
     graph, text = load_tag_pubmed()
     print(type(graph))
+    print(type(text))'''
+
+    graph = load_graph_citationv8()
+    print(type(graph))
+    graph, text = load_tag_citationv8()
+    print(type(text))
9 changes: 4 additions & 5 deletions core/gcns/nbfnet_tune.py
@@ -106,16 +106,15 @@ def parse_args() -> argparse.Namespace:
f"\n Valid: {2 * splits['train']['pos_edge_label'].shape[0]} samples,"
f"\n Test: {2 * splits['test']['pos_edge_label'].shape[0]} samples")
dump_cfg(cfg)
hyperparameter_search = {'hidden_channels': [32, 64, 128, 256], 'num_layers': [3, 4, 5, 6],
hyperparameter_search = {'hidden_channels': [32, 64, 128, 256],
"batch_size": [64, 128, 256, 512, 1024], "lr": [0.01, 0.001, 0.0001]}
print_logger.info(f"hypersearch space: {hyperparameter_search}")
for hidden_channels, num_layers, batch_size, lr in tqdm(itertools.product(*hyperparameter_search.values())):
for hidden_channels, batch_size, lr in tqdm(itertools.product(*hyperparameter_search.values())):
cfg.model.hidden_channels = hidden_channels
cfg.train.batch_size = batch_size
cfg.optimizer.lr = lr
cfg.model.num_layers = num_layers
print_logger.info(
f"hidden_channels: {hidden_channels}, num_layers: {num_layers}, batch_size: {batch_size}, lr: {lr}")
f"hidden_channels: {hidden_channels}, batch_size: {batch_size}, lr: {lr}")
start_time = time.time()
model = NBFNet(cfg.model.in_channels, [cfg.model.hidden_channels] * cfg.model.num_layers, num_relation = 1)

@@ -159,7 +158,7 @@ def parse_args() -> argparse.Namespace:
             run_result[key] = test_bvalid
 
         run_result.update(
-            {'hidden_channels': hidden_channels, ' num_layers': num_layers, 'batch_size': batch_size, 'lr': lr,
+            {'hidden_channels': hidden_channels, 'batch_size': batch_size, 'lr': lr,
              })
         print_logger.info(run_result)
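The sweep above relies on dict insertion order: itertools.product(*d.values()) yields tuples in key order, so dropping 'num_layers' from the dict and from the loop unpacking must happen together, as this commit does. A self-contained sketch of the pattern (illustrative values only):

import itertools

search = {'hidden_channels': [32, 64], 'batch_size': [128, 256], 'lr': [0.01, 0.001]}
# Tuples arrive in key order: (hidden_channels, batch_size, lr).
for hidden_channels, batch_size, lr in itertools.product(*search.values()):
    print(hidden_channels, batch_size, lr)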

42 changes: 34 additions & 8 deletions core/gcns/neognn_main.py
@@ -3,6 +3,8 @@

 from torch_sparse import SparseTensor
 
+from torch_geometric.graphgym import params_count
+
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 import argparse
 import time
@@ -22,29 +24,39 @@
 from graphgps.network.neognn import NeoGNN, LinkPredictor
 from data_utils.load import load_data_lp
 from graphgps.train.neognn_train import Trainer_NeoGNN
+from graphgps.network.ncn import predictor_dict



 def parse_args() -> argparse.Namespace:
     r"""Parses the command line arguments."""
     parser = argparse.ArgumentParser(description='GraphGym')
 
     parser.add_argument('--cfg', dest='cfg_file', type=str, required=False,
-                        default='core/yamls/cora/gcns/seal.yaml',
+                        default='core/yamls/cora/gcns/ncn.yaml',
                         help='The configuration file path.')
 
     parser.add_argument('--sweep', dest='sweep_file', type=str, required=False,
-                        default='core/yamls/cora/gcns/gae_sp1.yaml',
+                        default='core/yamls/cora/gcns/ncn.yaml',
                         help='The configuration file path.')
-    parser.add_argument('--data', dest='data', type=str, required=False,
-                        default='cora',
+    parser.add_argument('--data', dest='data', type=str, required=True,
+                        default='pubmed',
                         help='data name')
 
     parser.add_argument('--repeat', type=int, default=2,
                         help='The number of repeated jobs.')
+    parser.add_argument('--batch_size', dest='bs', type=int, required=False,
+                        default=2**15,
+                        help='batch size')
+    parser.add_argument('--device', dest='device', required=True,
+                        help='device id')
+    parser.add_argument('--epochs', dest='epoch', type=int, required=True,
+                        default=400,
+                        help='number of epochs')
+    parser.add_argument('--wandb', dest='wandb', required=False,
+                        help='wandb run name')
     parser.add_argument('--mark_done', action='store_true',
                         help='Mark yaml as done after a job has finished.')
     parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                         help='See graphgym/config.py for remaining options.')
 
     return parser.parse_args()
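With the flags above, a typical invocation would be python core/gcns/neognn_main.py --data cora --device cuda:0 --epochs 400 (values illustrative, not from the commit). Since --data, --device and --epochs are declared required=True, argparse ignores their defaults and all three must be passed explicitly.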

def ngnn_dataset(data, splits):
@@ -69,18 +81,27 @@ def ngnn_dataset(data, splits):
     cfg = set_cfg(FILE_PATH, args.cfg_file)
     cfg.merge_from_list(args.opts)
 
+    cfg.data.name = args.data
+
+    cfg.data.device = args.device
+    cfg.model.device = args.device
+    cfg.device = args.device
+    cfg.train.epochs = args.epoch
+
     torch.set_num_threads(cfg.num_threads)
     batch_sizes = [cfg.train.batch_size]
 
     best_acc = 0
     best_params = {}
     loggers = create_logger(args.repeat)
-    cfg.device = args.device
 
     for batch_size in batch_sizes:
         for run_id, seed, split_index in zip(
                 *run_loop_settings(cfg, args)):
             custom_set_run_dir(cfg, run_id)
-            set_printing(cfg)
+            print_logger = set_printing(cfg)
             cfg.seed = seed
             cfg.run_id = run_id
             seed_everything(cfg.seed)
@@ -113,7 +134,7 @@ def ngnn_dataset(data, splits):
                 run_id,
                 args.repeat,
                 loggers,
-                print_logger=None,
+                print_logger=print_logger,
                 batch_size=batch_size)
start = time.time()
@@ -130,3 +151,8 @@ def ngnn_dataset(data, splits):
             result_dict[key] = valid_test
 
         trainer.save_result(result_dict)
+
+        cfg.model.params = params_count(model)
+        print_logger.info(f'Num parameters: {cfg.model.params}')
+        trainer.finalize()
+        print_logger.info(f"Inference time: {trainer.run_result['eval_time']}")
