diff --git a/core/data_utils/load_data_lp.py b/core/data_utils/load_data_lp.py index e3faa2559f..9e86dec3fb 100644 --- a/core/data_utils/load_data_lp.py +++ b/core/data_utils/load_data_lp.py @@ -192,6 +192,26 @@ def load_taglp_citationv8(cfg: CN) -> Tuple[Dict[str, Data], List[str]]: ) return splits, text, data +def load_taplp_pwc_large(cfg: CN) -> Tuple[Dict[str, Data], List[str]]: + data = load_graph_pwc_large() + text = load_text_pwc_large() + + if data.is_directed() is True: + data.edge_index = to_undirected(data.edge_index) + undirected = True + else: + undirected = data.is_undirected() + + splits = get_edge_split(data, + undirected, + cfg.device, + cfg.split_index[1], + cfg.split_index[2], + cfg.include_negatives, + cfg.split_labels + ) + return splits, text, data + # TEST CODE if __name__ == '__main__': diff --git a/core/data_utils/load_data_nc.py b/core/data_utils/load_data_nc.py index 39a1bdd75c..2d80a0d36d 100644 --- a/core/data_utils/load_data_nc.py +++ b/core/data_utils/load_data_nc.py @@ -483,7 +483,19 @@ def load_tag_citeseer() -> Tuple[Data, List[str]]: text = load_text_citeseer() return graph, text +def load_graph_pwc_large(): + graph = torch.load(FILE_PATH+'core/dataset/pwc_large') + return graph + +def load_text_pwc_large() -> List[str]: + df = pd.read_csv(FILE_PATH + 'core/dataset/pwc_large/pwc_large_papers.csv') + return [ + f'Text: {ti}\n' + for ti in zip(df['feat']) + ] + + # Test code if __name__ == '__main__': graph = load_graph_citeseer() diff --git a/core/model_finetuning/mlp.py b/core/model_finetuning/mlp.py index 28cfd4e1cc..0e7fcc48eb 100644 --- a/core/model_finetuning/mlp.py +++ b/core/model_finetuning/mlp.py @@ -84,7 +84,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--score', dest='score', type=str, required=False, default='mlp_score', help='decoder name') - parser.add_argument('--max_iter', dest='max_iter', type=int, required=False, default=1000, + parser.add_argument('--max_iter', dest='max_iter', type=int, required=False, default=10, help='decoder name') parser.add_argument('--repeat', type=int, default=5, @@ -167,7 +167,7 @@ def project_main(): clf.partial_fit(train_dataset, train_labels, classes=classes) print(f'this epoch costs {time.time() - start}') - if i % 100 == 0: + if i % 10 == 0: # Calculate and print metrics for test set test_metrics = get_metrics(clf, test_dataset, test_labels, evaluator_hit, evaluator_mrr) print(test_metrics)