Skip to content

Commit

Permalink
add dataloader for pwc_large
Browse files Browse the repository at this point in the history
  • Loading branch information
ChenS676 committed Jul 9, 2024
1 parent 7c4bfd0 commit eb0f1ce
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
20 changes: 20 additions & 0 deletions core/data_utils/load_data_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,26 @@ def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]:
)
return splits, text, data

def load_taplp_pwc_large(cfg: CN) -> Tuple[Dict[str, Data], List[str], Data]:
    """Load the pwc_large graph and texts and build link-prediction splits.

    Args:
        cfg: Config node. This function reads ``device``, ``split_index``,
            ``include_negatives`` and ``split_labels`` from it.

    Returns:
        A ``(splits, text, data)`` triple: the value returned by
        ``get_edge_split``, the list returned by ``load_text_pwc_large`` and
        the graph object (its ``edge_index`` is replaced by the output of
        ``to_undirected`` when the graph reports itself as directed).
    """
    data = load_graph_pwc_large()
    text = load_text_pwc_large()

    if data.is_directed():
        # Symmetrise the edge list before splitting.
        data.edge_index = to_undirected(data.edge_index)
        undirected = True
    else:
        undirected = data.is_undirected()

    splits = get_edge_split(
        data,
        undirected,
        cfg.device,
        # NOTE(review): entries 1 and 2 of split_index are forwarded here;
        # presumably the validation/test ratios -- confirm against
        # get_edge_split's signature.
        cfg.split_index[1],
        cfg.split_index[2],
        cfg.include_negatives,
        cfg.split_labels,
    )
    return splits, text, data


# TEST CODE
if __name__ == '__main__':
Expand Down
12 changes: 12 additions & 0 deletions core/data_utils/load_data_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,19 @@ def load_tag_citeseer() -> Tuple[Data, List[str]]:
text = load_text_citeseer()
return graph, text

def load_graph_pwc_large():
    """Deserialize and return the pwc_large graph object via ``torch.load``."""
    # NOTE(review): the text loader reads a CSV from inside
    # 'core/dataset/pwc_large/', so this path looks like a directory rather
    # than a serialized file -- confirm the intended file name.
    return torch.load(FILE_PATH + 'core/dataset/pwc_large')


def load_text_pwc_large() -> List[str]:
    """Read the pwc_large papers CSV and return one prompt string per row.

    Returns:
        A list with one ``'Text: <feat>\\n'`` string for each value in the
        ``feat`` column of ``pwc_large_papers.csv``, in file order.
    """
    df = pd.read_csv(FILE_PATH + 'core/dataset/pwc_large/pwc_large_papers.csv')
    # Iterate the column directly: zip() over a single iterable yields
    # 1-tuples, which would be formatted as "Text: ('...',)".
    return [f'Text: {ti}\n' for ti in df['feat']]


# Test code
if __name__ == '__main__':
graph = load_graph_citeseer()
Expand Down
4 changes: 2 additions & 2 deletions core/model_finetuning/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument('--score', dest='score', type=str, required=False, default='mlp_score',
help='decoder name')

parser.add_argument('--max_iter', dest='max_iter', type=int, required=False, default=1000,
parser.add_argument('--max_iter', dest='max_iter', type=int, required=False, default=10,
help='decoder name')

parser.add_argument('--repeat', type=int, default=5,
Expand Down Expand Up @@ -167,7 +167,7 @@ def project_main():
clf.partial_fit(train_dataset, train_labels, classes=classes)
print(f'this epoch costs {time.time() - start}')

if i % 100 == 0:
if i % 10 == 0:
# Calculate and print metrics for test set
test_metrics = get_metrics(clf, test_dataset, test_labels, evaluator_hit, evaluator_mrr)
print(test_metrics)
Expand Down

0 comments on commit eb0f1ce

Please sign in to comment.