Skip to content

Commit

Permalink
symmetrize arxiv_2023
Browse files: browse the repository at this point in the history
  • Loading branch information
ChenS676 committed Jul 8, 2024
1 parent 6c6e166 commit b00be64
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 33 deletions.
67 changes: 45 additions & 22 deletions core/data_utils/load_data_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from torch_geometric.datasets import Planetoid
from torch_geometric.data import Data, InMemoryDataset, Dataset
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import to_undirected
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.preprocessing import normalize
from yacs.config import CfgNode as CN
Expand All @@ -28,6 +32,7 @@
from graphgps.utility.utils import time_logger
from typing import Dict, Tuple, List, Union


FILE = 'core/dataset/ogbn_products_orig/ogbn-products.csv'
FILE_PATH = get_git_repo_root_path() + '/'

Expand All @@ -37,8 +42,10 @@ def load_taglp_arxiv2023(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
# add one default argument

data, text = load_tag_arxiv23()
undirected = data.is_directed()

if data.is_directed() is True:
data.edge_index = to_undirected(data.edge_index)
undirected = True

splits = get_edge_split(data,
undirected,
cfg.device,
Expand All @@ -56,7 +63,7 @@ def load_taglp_cora(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
data, data_citeid = load_graph_cora(False)
text = load_text_cora(data_citeid)
# text = None
undirected = data.is_directed()
undirected = data.is_undirected()

splits = get_edge_split(data,
undirected,
Expand All @@ -74,7 +81,7 @@ def load_taglp_ogbn_arxiv(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:

data = load_graph_ogbn_arxiv(False)
text = load_text_ogbn_arxiv()
undirected = data.is_directed()
undirected = data.is_undirected()

cfg = config_device(cfg)

Expand Down Expand Up @@ -115,7 +122,7 @@ def load_taglp_product(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
# add one default argument

data, text = load_tag_product()
undirected = data.is_directed()
undirected = data.is_undirected()

cfg = config_device(cfg)

Expand All @@ -135,7 +142,7 @@ def load_taglp_pubmed(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:

data = load_graph_pubmed(False)
text = load_text_pubmed()
undirected = data.is_directed()
undirected = data.is_undirected()

splits = get_edge_split(data,
undirected,
Expand All @@ -152,7 +159,7 @@ def load_taglp_citeseer(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:

data = load_graph_citeseer()
text = load_text_citeseer()
undirected = data.is_directed()
undirected = data.is_undirected()

splits = get_edge_split(data,
undirected,
Expand All @@ -166,11 +173,15 @@ def load_taglp_citeseer(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:

def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
# add one default argument

data = load_graph_citationv8()
text = load_text_citationv8()
undirected = data.is_directed()

if data.is_directed() is True:
data.edge_index = to_undirected(data.edge_index)
undirected = True
else:
undirected = data.is_undirected()

splits = get_edge_split(data,
undirected,
cfg.device,
Expand All @@ -185,32 +196,44 @@ def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
# TEST CODE
if __name__ == '__main__':
args = init_cfg_test()
print(args)
'''data, text, __ = load_taglp_arxiv2023(args.data)
print('arxiv2023')
splits, text, data = load_taglp_arxiv2023(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))
data, text = load_taglp_cora(args.data)

print('citationv8')
splits, text, data = load_taglp_citationv8(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))

data, text = load_taglp_product(args.data)
exit(-1)
print('cora')
splits, text, data = load_taglp_cora(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))

data, text = load_taglp_pubmed(args.data)
print('product')
splits, text, data = load_taglp_product(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))'''
print(type(text))

splits, text, data = load_taglp_citeseer(args.data)
print('pubmed')
splits, text, data = load_taglp_pubmed(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))

splits, text, data = load_taglp_citationv8(args.data)
splits, text, data = load_taglp_citeseer(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))


print(args.data)
splits, text, data = load_taglp_ogbn_arxiv(args.data)
print(f'directed: {data.is_directed()}')
print(data)
print(type(text))
print(type(text))
21 changes: 11 additions & 10 deletions core/model_finetuning/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,17 @@ def project_main():
start = time.time()
clf.partial_fit(train_dataset, train_labels, classes=classes)
print(f'this epoch costs {time.time() - start}')

# Calculate and print metrics for test set
test_metrics = get_metrics(clf, test_dataset, test_labels, evaluator_hit, evaluator_mrr)
print(test_metrics)
# Calculate and print metrics for train set
train_metrics = get_metrics(clf, train_dataset, train_labels, evaluator_hit, evaluator_mrr)
print(train_metrics)
# Calculate and print metrics for validation set
val_metrics = get_metrics(clf, val_dataset, val_labels, evaluator_hit, evaluator_mrr)
print(val_metrics)

if i % 100 == 0:
# Calculate and print metrics for test set
test_metrics = get_metrics(clf, test_dataset, test_labels, evaluator_hit, evaluator_mrr)
print(test_metrics)
# Calculate and print metrics for train set
train_metrics = get_metrics(clf, train_dataset, train_labels, evaluator_hit, evaluator_mrr)
print(train_metrics)
# Calculate and print metrics for validation set
val_metrics = get_metrics(clf, val_dataset, val_labels, evaluator_hit, evaluator_mrr)
print(val_metrics)

results_rank = {
key: (train_metrics[key], val_metrics[key], test_metrics[key])
Expand Down
2 changes: 1 addition & 1 deletion core/model_finetuning/scripts/pubmed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ module load compiler/gnu/12
cd /hkfs/work/workspace/scratch/cc7738-benchmark_tag/TAPE_chen/core/model_finetuning


for iter in 1000 2000; do
for iter in 1500; do
echo "python mlp.py --data pubmed --decoder MLP --max_iter $iter"
python mlp.py --data pubmed --decoder MLP --max_iter $iter
done

0 comments on commit b00be64

Please sign in to comment.