From 18a12c3591c65921d2cd2475187bb183807e83c0 Mon Sep 17 00:00:00 2001 From: Misha D Date: Tue, 3 Dec 2024 15:27:54 +0300 Subject: [PATCH 1/2] fix dataset conversion; fixed example datasets, torch geom list --- data/multiple-graphs/custom/example/raw/.info | 1 + .../custom/small/raw/small.node_attributes/a | 12 +- .../custom/small/raw/small.node_attributes/b | 12 +- data/single-graph/custom/example/raw/.info | 1 + metainfo/torch_geom_index.json | 12 +- src/aux/utils.py | 21 + src/base/custom_datasets.py | 82 +++- src/base/dataset_converter.py | 438 ++++++++++++++---- src/base/datasets_processing.py | 172 +++++-- tests/datasets_test.py | 100 +++- tests/explainers_full_test.py | 8 +- tests/models_test.py | 6 +- 12 files changed, 693 insertions(+), 172 deletions(-) diff --git a/data/multiple-graphs/custom/example/raw/.info b/data/multiple-graphs/custom/example/raw/.info index 5640a31..2824d8c 100644 --- a/data/multiple-graphs/custom/example/raw/.info +++ b/data/multiple-graphs/custom/example/raw/.info @@ -1,4 +1,5 @@ { + "name": "example", "count": 3, "nodes": [3, 4, 5], "directed": false, diff --git a/data/multiple-graphs/custom/small/raw/small.node_attributes/a b/data/multiple-graphs/custom/small/raw/small.node_attributes/a index 8cf60ee..cd96a34 100644 --- a/data/multiple-graphs/custom/small/raw/small.node_attributes/a +++ b/data/multiple-graphs/custom/small/raw/small.node_attributes/a @@ -4,8 +4,7 @@ "1": 0, "2": 1, "3": 1, - "4": 1, - "5": 0 + "4": 1 }, { "0": 1, @@ -17,8 +16,7 @@ "0": 0, "1": 1, "2": 1, - "3": 1, - "4": 0 + "3": 1 }, { "0": 0, @@ -45,8 +43,7 @@ "3": 1, "4": 1, "5": 1, - "6": 1, - "7": 0 + "6": 1 }, { "0": 1, @@ -55,8 +52,7 @@ "3": 1, "4": 1, "5": 1, - "6": 1, - "7": 0 + "6": 1 }, { "0": 0, diff --git a/data/multiple-graphs/custom/small/raw/small.node_attributes/b b/data/multiple-graphs/custom/small/raw/small.node_attributes/b index f06bd9f..71df1fe 100644 --- a/data/multiple-graphs/custom/small/raw/small.node_attributes/b +++ b/data/multiple-graphs/custom/small/raw/small.node_attributes/b @@ -4,8 +4,7 @@ "1": 1, "2": 1, "3": 0, - "4": 0, - "5": 1 + "4": 0 }, { "0": 0, @@ -17,8 +16,7 @@ "0": 0, "1": 0, "2": 0, - "3": 0, - "4": 1 + "3": 0 }, { "0": 1, @@ -45,8 +43,7 @@ "3": 1, "4": 0, "5": 1, - "6": 1, - "7": 1 + "6": 1 }, { "0": 0, @@ -55,8 +52,7 @@ "3": 0, "4": 0, "5": 1, - "6": 0, - "7": 1 + "6": 0 }, { "0": 1, diff --git a/data/single-graph/custom/example/raw/.info b/data/single-graph/custom/example/raw/.info index b64bd7d..1f76a2f 100644 --- a/data/single-graph/custom/example/raw/.info +++ b/data/single-graph/custom/example/raw/.info @@ -1,4 +1,5 @@ { + "name": "example", "count": 1, "directed": false, "nodes": [8], diff --git a/metainfo/torch_geom_index.json b/metainfo/torch_geom_index.json index 8ca68e6..34af805 100644 --- a/metainfo/torch_geom_index.json +++ b/metainfo/torch_geom_index.json @@ -9,9 +9,7 @@ "pytorch-geometric-other": [ "Actor", "BAShapes", - "Flickr", - "KarateClub", - "Reddit2" + "KarateClub" ], "Planetoid": [ "CiteSeer", @@ -32,7 +30,6 @@ "AIDS", "BZR", "BZR_MD", - "COIL-DEL", "COX2", "COX2_MD", "Cuneiform", @@ -128,13 +125,6 @@ "salicylic_acid", "toluene", "uracil" - ], - "MoleculeNet": [ - "MUV", - "PCBA", - "SIDER", - "Tox21", - "ToxCast" ] } } diff --git a/src/aux/utils.py b/src/aux/utils.py index 017951b..6084701 100644 --- a/src/aux/utils.py +++ b/src/aux/utils.py @@ -161,3 +161,24 @@ def all_subclasses( ) -> set: return set(cls.__subclasses__()).union( [s for c in cls.__subclasses__() for s in all_subclasses(c)]) + + +class tmp_dir(): + 
""" + Temporary create a directory near the given path. Remove it on exit. + """ + def __init__(self, path: Path): + self.path = path + from time import time + self.tmp_dir = self.path.parent / (self.path.name + str(time())) + + def __enter__(self): + self.tmp_dir.mkdir(parents=True) + return self.tmp_dir + + def __exit__(self, exception_type, exception_value, exception_traceback): + import shutil + try: + shutil.rmtree(self.tmp_dir) + except FileNotFoundError: + pass diff --git a/src/base/custom_datasets.py b/src/base/custom_datasets.py index 2be1bf5..ebe3e3a 100644 --- a/src/base/custom_datasets.py +++ b/src/base/custom_datasets.py @@ -28,7 +28,7 @@ def __init__( """ super().__init__(dataset_config) - assert self.labels_dir.exists() + # assert self.labels_dir.exists() self.info = DatasetInfo.read(self.info_path) self.node_map = None # Optional nodes mapping: node_map[i] = original id of node i self.edge_index = None @@ -68,6 +68,86 @@ def edge_index_path( """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.edge_index') + def check_validity( + self + ): + """ Check that dataset files (graph and attributes) are valid and consistent with .info. + """ + # Assuming info is OK + count = self.info.count + # Check edges + if self.is_multi(): + with open(self.edges_path, 'r') as f: + num_edges = sum(1 for _ in f) + with open(self.edge_index_path, 'r') as f: + edge_index = json.load(f) + assert all(i <= num_edges for i in edge_index) + assert num_edges == edge_index[-1] + assert count == len(edge_index) + + # Check nodes + all_nodes = [set() for _ in range(count)] # sets of nodes + if self.is_multi(): + with open(self.edges_path, 'r') as f: + start = 0 + for ix, end in enumerate(edge_index): + for _ in range(end-start): + all_nodes[ix].update(map(int, f.readline().split())) + if self.info.remap: + assert len(all_nodes[ix]) == self.info.nodes[ix] + else: + assert all_nodes[ix] == set(range(self.info.nodes[ix])) + start = end + else: + with open(self.edges_path, 'r') as f: + for line in f.readlines(): + all_nodes[0].update(map(int, line.split())) + if self.info.remap: + assert len(all_nodes[0]) == self.info.nodes[0] + else: + assert all_nodes[0] == set(range(self.info.nodes[0])) + + # Check node attributes + for ix, attr in enumerate(self.info.node_attributes["names"]): + with open(self.node_attributes_dir / attr, 'r') as f: + node_attributes = json.load(f) + if not self.is_multi(): + node_attributes = [node_attributes] + for i, attributes in enumerate(node_attributes): + assert all_nodes[i] == set(map(int, attributes.keys())) + if self.info.node_attributes["types"][ix] == "continuous": + v_min, v_max = self.info.node_attributes["values"][ix] + assert all(isinstance(v, (int, float, complex)) for v in attributes.values()) + assert min(attributes.values()) >= v_min + assert max(attributes.values()) <= v_max + elif self.info.node_attributes["types"][ix] == "categorical": + assert set(attributes.values()).issubset(set(self.info.node_attributes["values"][ix])) + + # Check edge attributes + for ix, attr in enumerate(self.info.edge_attributes["names"]): + with open(self.edge_attributes_dir / attr, 'r') as f: + edge_attributes = json.load(f) + if not self.is_multi(): + edge_attributes = [edge_attributes] + for i, attributes in enumerate(edge_attributes): + # TODO check edges + if self.info.edge_attributes["types"][ix] == "continuous": + v_min, v_max = self.info.edge_attributes["values"][ix] + assert all(isinstance(v, (int, float, complex)) for v in attributes.values()) + assert 
min(attributes.values()) >= v_min
+                    assert max(attributes.values()) <= v_max
+                elif self.info.edge_attributes["types"][ix] == "categorical":
+                    assert set(attributes.values()).issubset(set(self.info.edge_attributes["values"][ix]))
+
+        # Check labels
+        for labelling, n_classes in self.info.labelings.items():
+            with open(self.labels_dir / labelling, 'r') as f:
+                labels = json.load(f)
+            if self.is_multi():  # graph labels
+                assert set(range(count)) == set(map(int, labels.keys()))
+            else:  # node labels
+                assert all_nodes[0] == set(map(int, labels.keys()))
+
     def build(
         self,
         dataset_var_config: Union[ConfigPattern, DatasetVarConfig]
diff --git a/src/base/dataset_converter.py b/src/base/dataset_converter.py
index 93fb6f2..9a71a5a 100644
--- a/src/base/dataset_converter.py
+++ b/src/base/dataset_converter.py
@@ -1,8 +1,175 @@
-import os
+import json
+from pathlib import Path
 from torch_geometric.utils import to_networkx, from_networkx
 import networkx as nx
+
+from base.datasets_processing import DatasetInfo
 from base.ptg_datasets import is_graph_directed
-from src.aux.utils import GRAPHS_DIR
+
+
+class DatasetConverter:
+    """
+    Converts graph data from one format to another.
+    """
+    supported_formats = ["adjlist", "edgelist", "gml", "g6", "s6"]
+
+    @staticmethod
+    def format_to_ij(
+        info: DatasetInfo,
+        graph_files: [Path],
+        format: str,
+        output_dir: Path,
+        default_node_attr_value: dict = None,
+        default_edge_attr_value: dict = None,
+    ) -> None:
+        """
+        Convert graph data to our 'ij' format.
+        """
+        assert format in DatasetConverter.supported_formats
+        if format in ['g6', 's6']:
+            if info.remap:
+                raise RuntimeError(f"Graphs in '{format}' format don't support node remapping;"
+                                   f" nodes must be numbered from 0 to N-1.")
+
+        # Read to networkx
+        create_using = nx.DiGraph if info.directed else nx.Graph
+        graphs = []
+        for path in graph_files:
+            graph = read_nx_graph(format, path, create_using=create_using)
+            graphs.append(graph)
+
+        assert len(graphs) == info.count
+
+        # Extract attributes
+        all_node_attributes = []
+        all_edge_attributes = []
+        for graph in graphs:
+            node_attributes, edge_attributes = extract_attributes(
+                graph, default_node_attr_value, default_edge_attr_value)
+            all_node_attributes.append(node_attributes)
+            all_edge_attributes.append(edge_attributes)
+
+        # Write graphs and attributes to output dir
+        with open(output_dir / f'{info.name}.ij', 'w') as f:
+            for graph in graphs:
+                for i, j in graph.edges:
+                    f.write(f'{i} {j}\n')
+        if len(graphs) > 1:
+            edge_index = []
+            edges = 0
+            for graph in graphs:
+                edges += graph.number_of_edges()
+                edge_index.append(edges)
+            with open(output_dir / f'{info.name}.edge_index', 'w') as f:
+                json.dump(edge_index, f)
+
+        node_attr_dir = output_dir / f'{info.name}.node_attributes'
+        node_attr_dir.mkdir()
+        edge_attr_dir = output_dir / f'{info.name}.edge_attributes'
+        edge_attr_dir.mkdir()
+        if len(graphs) == 1:
+            all_node_attributes = all_node_attributes[0]
+            all_edge_attributes = all_edge_attributes[0]
+            for attr, data in all_node_attributes.items():
+                with open(node_attr_dir / attr, 'w') as f:
+                    json.dump(data, f)
+            for attr, data in all_edge_attributes.items():
+                with open(edge_attr_dir / attr, 'w') as f:
+                    json.dump(data, f)
+        else:
+            keys = set.union(*[set(a.keys()) for a in all_node_attributes])
+            for attr in keys:
+                with open(node_attr_dir / attr, 'w') as f:
+                    json.dump([a[attr] for a in all_node_attributes], f)
+            keys = set.union(*[set(a.keys()) for a in all_edge_attributes])
+            for attr in keys:
+                with open(edge_attr_dir / attr, 'w') 
as f: + json.dump([a[attr] for a in all_edge_attributes], f) + + # FIXME add new attri to DatasetInfo? + + @staticmethod + def networkx_to_format( + graph: nx.Graph, + format: str, + output_dir: Path, + default_node_attr_value: dict = None, + default_edge_attr_value: dict = None, + name: str = 'networkx_graph' + ) -> None: + """ + Write a networkx graph to files according to a specified format. + Attribute files will be created if necessary. + """ + node_attributes, edge_attributes = extract_attributes( + graph, default_node_attr_value, default_edge_attr_value) + + graph_file = output_dir / f'{name}.{format}' + node_attrs_ok = False + edge_attrs_ok = False + if format == "adjlist": + nx.write_adjlist(graph, graph_file) + + elif format == "edgelist": + nx.write_edgelist(graph, graph_file) + + elif format == "gml": + nx.write_gml(graph, graph_file) + node_attrs_ok = True + edge_attrs_ok = True + + elif format == "g6": + nx.write_graph6(graph, str(graph_file)) + + elif format == "s6": + nx.write_sparse6(graph, graph_file) + + else: + raise NotImplementedError + + node_attributes_dir = output_dir / f'{name}.node_attributes' + edge_attributes_dir = output_dir / f'{name}.edge_attributes' + + if not node_attrs_ok: + node_attributes_dir.mkdir() + for attr, data in node_attributes.items(): + with open(node_attributes_dir / str(attr), 'w') as f: + json.dump(node_attributes[attr], f) + if not edge_attrs_ok: + edge_attributes_dir.mkdir() + for attr, data in edge_attributes.items(): + with open(edge_attributes_dir / str(attr), 'w') as f: + json.dump(edge_attributes[attr], f) + + +def extract_attributes( + graph: nx.Graph, + default_node_attr_value: dict = None, + default_edge_attr_value: dict = None, +) -> (dict, dict): + """ + Extract nodes and edges attributes from a networkx graph. 
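+    Attributes missing on some node or edge are filled from default_node_attr_value /
+    default_edge_attr_value, which must then provide a value for each such attribute
+    (otherwise a lookup error is raised).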
+ """ + all_node_attributes_names = set() + all_edge_attributes_names = set() + for node in graph.nodes(data=True): + all_node_attributes_names.update(node[1].keys()) + for edge in graph.edges(data=True): + all_edge_attributes_names.update(edge[2].keys()) + node_attributes = {attr: {} for attr in all_node_attributes_names} + edge_attributes = {attr: {} for attr in all_edge_attributes_names} + # for attr in all_node_attributes_names: + # node_attributes[attr] = nx.get_node_attributes(graph, attr, default_node_attr_value[attr]) + # for attr in all_edge_attributes_names: + # edge_attributes[attr] = nx.get_edge_attributes(graph, attr, default_edge_attr_value[attr]) + for n, data in graph.nodes(data=True): + for attr in all_node_attributes_names: + node_attributes[attr][n] = data[attr] if attr in data else default_node_attr_value[attr] + for i, j, data in graph.edges(data=True): + for attr in all_edge_attributes_names: + edge_attributes[attr][f"{i},{j}"] = data[attr] if attr in data else default_edge_attr_value[attr] + # edge_attributes[attr][f"{i},{j}"] = data.get(attr, default_edge_attr_value[attr]) + return node_attributes, edge_attributes def ptg_to_networkx(ptg_graph): @@ -29,28 +196,23 @@ def networkx_to_ptg(nx_graph): for attribute_name in data: node_attribute_names.add(attribute_name) - # Checking that attributes have numeric types - for attr in node_attribute_names: - attr_name = nx.get_node_attributes(nx_graph, attr) - for name in attr_name: - if not isinstance(attr_name[name], (int, float, complex, list)): - raise RuntimeError("Wrong NODE attribute type!!!") - # Iterating through the edges and collect unique attribute names for u, v, data in nx_graph.edges(data=True): for attribute_name in data: edge_attribute_names.add(attribute_name) - # Checking that attributes have numeric types - for attr in edge_attribute_names: - attr_name = nx.get_edge_attributes(nx_graph, attr) - for name in attr_name: - if not isinstance(attr_name[name], (int, float, complex)): - raise RuntimeError("Wrong EDGE attribute type!!!", nx_graph) + node_attribute_names_list = [] + edge_attribute_names_list = [] + # Get only attributes that have numeric types + for attr in sorted(node_attribute_names): + if all(isinstance(v, (int, float, complex, list)) + for v in nx.get_node_attributes(nx_graph, attr).values()): + node_attribute_names_list.append(attr) - # Converting the set of unique attribute names to a list - node_attribute_names_list = list(node_attribute_names) - edge_attribute_names_list = list(edge_attribute_names) + for attr in sorted(edge_attribute_names): + if all(isinstance(v, (int, float, complex, list)) + for v in nx.get_edge_attributes(nx_graph, attr).values()): + edge_attribute_names_list.append(attr) if len(node_attribute_names_list) < 1: node_attribute_names_list = None @@ -62,52 +224,52 @@ def networkx_to_ptg(nx_graph): return ptg_graph -def read_nx_graph(data_format, path): +def read_nx_graph(data_format, path, **kwargs): # FORMATS THAT ARE NOT SUPPORTED: # gexf, multiline_adjlist, weighted_edgelist - if data_format == ".adjlist": # This format does not store graph or node attributes. - return nx.read_adjlist(path) - elif data_format == ".edgelist": - return nx.read_edgelist(path) - elif data_format == ".gml": # Only works with graphs that have node, edge attributes + if data_format == "adjlist": # This format does not store graph or node attributes. 
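+        # Extra kwargs are forwarded to networkx, e.g. create_using=nx.DiGraph
+        # (as DatasetConverter.format_to_ij passes for directed datasets).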
+ return nx.read_adjlist(path, **kwargs) + elif data_format == "edgelist": + return nx.read_edgelist(path, **kwargs) + elif data_format == "gml": # Only works with graphs that have node, edge attributes return nx.read_gml(path) # # GRAPHML DOESN'T WORK WITH from_networkx() # elif data_format == "graphml": - # return nx.read_graphml(path) + # return nx.read_graphml(path, **kwargs) # # LEDA format is not supported as it stores edge attributes as strings # elif data_format == "leda": - # return nx.read_leda(path) - elif data_format == ".g6": + # return nx.read_leda(path, **kwargs) + elif data_format == "g6": return nx.read_graph6(path) - elif data_format == ".s6": + elif data_format == "s6": return nx.read_sparse6(path) # # PAJEK format is not supported as it stores node attributes as strings # elif data_format == "pajek": # Only works with graphs that have node labels - # return nx.read_pajek(path) + # return nx.read_pajek(path, **kwargs) else: raise RuntimeError("the READING format is NOT SUPPORTED!!!") def write_nx_graph(graph, data_format, path): - if data_format == ".adjlist": + if data_format == "adjlist": return nx.write_adjlist(graph, path) # elif data_format == "multiline_adjlist": # return nx.write_multiline_adjlist(graph, path) - elif data_format == ".edgelist": + elif data_format == "edgelist": return nx.write_edgelist(graph, path) # elif data_format == "weighted_edgelist": # return nx.write_weighted_edgelist(graph, path) # elif data_format == "gexf": # return nx.write_gexf(graph, path) - elif data_format == ".gml": + elif data_format == "gml": return nx.write_gml(graph, path) # elif data_format == "graphml": # return nx.write_graphml(graph, path) # elif data_format == "leda": # return nx.write_leda(graph, path) - elif data_format == ".g6": + elif data_format == "g6": return nx.write_graph6(graph, path) - elif data_format == ".s6": + elif data_format == "s6": return nx.write_sparse6(graph, path) # elif data_format == "pajek": # return nx.write_pajek(graph, path) @@ -115,66 +277,154 @@ def write_nx_graph(graph, data_format, path): raise RuntimeError("the WRITING format is NOT SUPPORTED!!!") -# Reading NX graphs from files in given formats -def read_nx_graphs(format_list, path_list): - # Creating a dict where keys are formats and values are lists of graphs - nx_graphs_dict = {format: [] for format in format_list} - for format in format_list: - for path in path_list: - if path.endswith(format): - nx_graph = read_nx_graph(format, path) - if (type(nx_graph) == list): - for graph in nx_graph: - nx_graphs_dict[format].append(graph) - break - nx_graphs_dict[format].append(nx_graph) - break - return nx_graphs_dict - - -def converting_func(nx_graphs_dict): - ptg_graphs_dict = {format: [] for format in nx_graphs_dict.keys()} - for format in nx_graphs_dict.keys(): - for graph in nx_graphs_dict[format]: - ptg_graph = networkx_to_ptg(graph) - ptg_graphs_dict[format].append(ptg_graph) - - new_nx_graphs_dict = {format: [] for format in ptg_graphs_dict.keys()} - for format in ptg_graphs_dict.keys(): - for graph in ptg_graphs_dict[format]: - nx_graph = ptg_to_networkx(graph) - new_nx_graphs_dict[format].append(nx_graph) - - return new_nx_graphs_dict +def example_single(): + g = nx.Graph() + g.add_node(11, a=0.4, b=100) + g.add_node(12, a=0.3, b=50) + g.add_node(13, a=0.3, b=200) + g.add_node(14, a=0.2, b=75) + g.add_node(15, a=0.4, b=25) + g.add_node(16, a=0.2, b=150) + g.add_node(17, a=0.5, b=80) + g.add_node(18, a=0.1, b=40) + g.add_edge(11, 12, weight=5, type='big') + g.add_edge(11, 13, 
weight=3, type='medium') + g.add_edge(11, 14, weight=4, type='small') + g.add_edge(12, 15, weight=2, type='big') + g.add_edge(12, 16, weight=6, type='big') + g.add_edge(13, 14, weight=3, type='medium') + g.add_edge(13, 17, weight=5, type='small') + g.add_edge(14, 18, weight=4, type='big') + g.add_edge(15, 16, weight=1, type='small') + g.add_edge(16, 17, weight=5, type='small') + g.add_edge(17, 18, weight=3, type='medium') + + from aux.configs import DatasetConfig + from base.datasets_processing import DatasetManager + from aux.declaration import Declare + + name = 'example_gml' + dc = DatasetConfig('single-graph', 'custom', name) + + # Create directory + root, files_paths = Declare.dataset_root_dir(dc) + raw = root / 'raw' + raw.mkdir(parents=True) + + # Write info and labels + nx.write_gml(g, raw / 'graph.gml') + with open(raw / '.info', 'w') as f: + json.dump({ + "name": name, + "count": 1, + "directed": True, + "nodes": [g.number_of_nodes()], + "remap": True, + "node_attributes": { + "names": ["a", "b"], + "types": ["continuous", "continuous"], + "values": [[0, 1], [0, 200]] + }, + "edge_attributes": { + "names": ["weight", "type"], + "types": ["continuous", "categorical"], + "values": [[1, 6], ['small', 'medium', 'big']] + }, + "labelings": {"binary": 2} + }, f) + + (raw / f'{name}.labels').mkdir() + with open(raw / f'{name}.labels' / 'binary', 'w') as f: + json.dump({"11": 1, "12": 0, "13": 0, "14": 0, "15": 0, "16": 0, "17": 0, "18": 0}, f) + + custom_dataset = DatasetManager.register_custom( + dc, 'gml', + default_node_attr_value={'a': -1, 'b': -1}, + default_edge_attr_value={'weight': -1, 'type': -1}) + custom_dataset.check_validity() + + +def example_multi(): + # Multi + g1 = nx.Graph() + g1.add_node(1, a=10, b='alpha') + g1.add_node(2, a=20, b='beta') + g1.add_node(3, a=30, b='gamma') + g1.add_node(4, a=40, b='delta') + g1.add_edge(1, 2, weight=1.5, type='mixed') + g1.add_edge(2, 3, weight=2.7) + g1.add_edge(3, 4, type='complex') + g1.add_edge(1, 4, weight=0.9, type='hybrid') + + g2 = nx.Graph() + g2.add_node(1, a=15, b='alpha') + g2.add_node(2, a=25, b='beta') + g2.add_node(3, a=35, b='gamma') + g2.add_node(4, a=45, b='delta') + g2.add_edge(1, 2, weight=1.2, type='mixed') + g2.add_edge(2, 3, weight=2.3, type='complex') + g2.add_edge(3, 4, weight=3.4) + g2.add_edge(1, 4, weight=4.5, type='hybrid') + + g3 = nx.Graph() + g3.add_node(1, a=20, b='alpha') + g3.add_node(2, a=30, b='beta') + g3.add_node(3, a=40, b='gamma') + g3.add_node(4, a=50, b='delta') + g3.add_node(5, a=60) + g3.add_edge(1, 2, weight=1.8, type='mixed') + g3.add_edge(2, 3) + g3.add_edge(3, 4, weight=2.5, type='complex') + g3.add_edge(4, 5, type='hybrid') + g3.add_edge(1, 5, weight=3.2) + + from aux.configs import DatasetConfig + from base.datasets_processing import DatasetManager + from aux.declaration import Declare + + name = 'example_gml' + dc = DatasetConfig('multiple-graphs', 'custom', name) + + # Create directory + root, files_paths = Declare.dataset_root_dir(dc) + raw = root / 'raw' + raw.mkdir(parents=True) + + # Write info and labels + nx.write_gml(g1, raw / 'graph1.gml') + nx.write_gml(g2, raw / 'graph2.gml') + nx.write_gml(g3, raw / 'graph3.gml') + with open(raw / '.info', 'w') as f: + json.dump({ + "name": name, + "count": 3, + "directed": False, + "nodes": [g.number_of_nodes() for g in [g1,g2,g3]], + "remap": True, + "node_attributes": { + "names": ["a", "b"], + "types": ["continuous", "categorical"], + "values": [[0, 100], ['alpha', 'beta', 'gamma', 'delta']] + }, + "edge_attributes": { + "names": 
["weight", "type"], + "types": ["continuous", "categorical"], + "values": [[0, 5], ['mixed', 'complex', 'hybrid']] + }, + "labelings": {"binary": 2} + }, f) + + (raw / f'{name}.labels').mkdir() + with open(raw / f'{name}.labels' / 'binary', 'w') as f: + json.dump({"0":1,"1":0,"2":0}, f) + + custom_dataset = DatasetManager.register_custom( + dc, 'gml', + default_node_attr_value={'a': 0, 'b': 'alpha'}, + default_edge_attr_value={'weight': 1, 'type': 'mixed'}) + custom_dataset.check_validity() -# Writing graphs from graphs_dict into given directory, file names will be "./output_graph." -def write_nx_graphs(graphs_dict, output_dir): # "output_dir" should be a str - for format in graphs_dict.keys(): - i = 0 - for graph in graphs_dict[format]: - write_nx_graph(graph, format, f"{output_dir}/output_graph{i}{format}") - i += 1 - return - - -# Extracting paths from a given directory -def path_handler(dir_path): - path_list = [] - for file_path in os.listdir(dir_path): - if os.path.isfile(os.path.join(dir_path, file_path)): - path_list.append(os.path.join(dir_path, file_path)) - return path_list - - -formats_list = [".adjlist", ".edgelist", ".gml", ".g6", ".s6"] - -input_dir = GRAPHS_DIR / 'networkx-graphs' / 'input' -output_dir = GRAPHS_DIR / 'networkx-graphs' / 'output' - if __name__ == '__main__': - input_paths = path_handler(input_dir) - nx_graphs_dict = read_nx_graphs(formats_list, input_paths) - new_nx_graphs_dict = converting_func(nx_graphs_dict) - - write_nx_graphs(new_nx_graphs_dict, str(output_dir)) + # example_single() + example_multi() diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index 1fce165..d0c9e11 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -1,4 +1,5 @@ import json +import shutil import os from pathlib import Path from typing import Union, Type @@ -11,19 +12,19 @@ from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from aux.custom_decorators import timing_decorator from aux.declaration import Declare -from aux.utils import TORCH_GEOM_GRAPHS_PATH +from aux.utils import TORCH_GEOM_GRAPHS_PATH, tmp_dir class DatasetInfo: """ - Description for a dataset family. + Description for a dataset. Some fields are obligate, others are not. """ def __init__( self ): - self.name: str = "" + self.name: str = None self.count: int = None self.directed: bool = None self.nodes: list = None @@ -36,6 +37,7 @@ def __init__( } self.labelings: dict = None self.node_attr_slices: dict = None + self.edge_attr_slices: dict = None self.node_info: dict = {} self.edge_info: dict = {} self.graph_info: dict = {} @@ -76,7 +78,7 @@ def check_consistency( def check_sufficiency( self ) -> None: - """ Check all obligates fields are defined. """ + """ Check all obligate fields are defined. """ for attr in self.__dict__.keys(): if attr is None: raise ValueError(f"Attribute '{attr}' of metainfo should be defined.") @@ -85,7 +87,7 @@ def check_consistency_with_dataset( self, dataset: Dataset ) -> None: - """ Check if metainfo fields are consistent with dataset. """ + """ Check if metainfo fields are consistent with PTG dataset. """ assert self.count == len(dataset) from base.ptg_datasets import is_graph_directed assert self.directed == is_graph_directed(dataset.get(0)) @@ -122,7 +124,8 @@ def save( def induce( dataset: Dataset ): - """ Induce metainfo from a given PTG dataset. """ + """ Induce metainfo from a given PTG dataset. 
+        """
         res = DatasetInfo()
         res.count = len(dataset)
         from base.ptg_datasets import is_graph_directed
@@ -134,7 +137,7 @@ def induce(
             "values": [len(dataset.get(0).x[0])]
         }
         res.labelings = {"origin": dataset.num_classes}
-        res.node_attr_slices = res.node_attributes_to_node_attr_slices(res.node_attributes)
+        res.node_attr_slices, res.edge_attr_slices = res.get_attributes_slices_from_attributes(res.node_attributes, res.edge_attributes)
         res.check()
         return res
 
     @staticmethod
     def read(
@@ -148,27 +151,45 @@ def read(
         res = DatasetInfo()
         for k, v in a_dict.items():
             setattr(res, k, v)
-        res.node_attr_slices = res.node_attributes_to_node_attr_slices(res.node_attributes)
+        res.node_attr_slices, res.edge_attr_slices = res.get_attributes_slices_from_attributes(
+            res.node_attributes, res.edge_attributes)
         res.check()
         return res
 
     @staticmethod
-    def node_attributes_to_node_attr_slices(
-        node_attributes: dict
-    ) -> dict:
+    def get_attributes_slices_from_attributes(
+        node_attributes: dict,
+        edge_attributes: dict,
+    ) -> (dict, dict):
         node_attr_slices = {}
-        start_attr_index = 0
-        for i in range(len(node_attributes['names'])):
-            if node_attributes['types'][i] == 'other':
-                attr_len = node_attributes['values'][i]
-            elif node_attributes['types'][i] == 'categorical':
-                attr_len = len(node_attributes['values'][i])
-            else:
-                attr_len = 1
-            node_attr_slices[node_attributes['names'][i]] = (
-                start_attr_index, start_attr_index + attr_len)
-            start_attr_index = start_attr_index + attr_len
-        return node_attr_slices
+        if node_attributes:
+            start_attr_index = 0
+            for i in range(len(node_attributes['names'])):
+                if node_attributes['types'][i] == 'other':
+                    attr_len = node_attributes['values'][i]
+                elif node_attributes['types'][i] == 'categorical':
+                    attr_len = len(node_attributes['values'][i])
+                else:
+                    attr_len = 1
+                node_attr_slices[node_attributes['names'][i]] = (
+                    start_attr_index, start_attr_index + attr_len)
+                start_attr_index = start_attr_index + attr_len
+
+        edge_attr_slices = {}
+        if edge_attributes:
+            start_attr_index = 0
+            for i in range(len(edge_attributes['names'])):
+                if edge_attributes['types'][i] == 'other':
+                    attr_len = edge_attributes['values'][i]
+                elif edge_attributes['types'][i] == 'categorical':
+                    attr_len = len(edge_attributes['values'][i])
+                else:
+                    attr_len = 1
+                edge_attr_slices[edge_attributes['names'][i]] = (
+                    start_attr_index, start_attr_index + attr_len)
+                start_attr_index = start_attr_index + attr_len
+
+        return node_attr_slices, edge_attr_slices
 
 
 class VisiblePart:
@@ -831,7 +852,7 @@ def get_by_config(
             gen_dataset = CustomDataset(dataset_config)
 
         elif dataset_group in ["vk_samples"]:
-            # TODO misha - it is a kind of custom?
+            # FIXME misha - it is a kind of custom?
             from base.vk_datasets import VKDataset
             gen_dataset = VKDataset(dataset_config)
 
@@ -905,13 +926,69 @@ def register_torch_geometric_api(
         return gen_dataset
 
     @staticmethod
-    def register_custom_ij(
-        path: Path
+    def register_custom(
+        dataset_config: DatasetConfig,
+        format: str = 'ij',
+        default_node_attr_value: dict = None,
+        default_edge_attr_value: dict = None,
     ) -> GeneralDataset:
         """
-        :return: GeneralDataset
+        Create a GeneralDataset from user-created files in one of the supported formats.
+        Attribute files created by the user take priority over attributes extracted from the graph file.
+
+        :param dataset_config: config for a new dataset. Files will be searched for in the folder
+            defined by this config.
+        :param format: one of the supported formats.
+        :param default_node_attr_value: dict with default node attribute values to apply where
+            missing. 
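+            E.g. {'a': 0} fills attribute 'a' with 0 on nodes that do not define it.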
+        :param default_edge_attr_value: dict with default edge attribute values to apply where
+            missing.
+        :return: CustomDataset
         """
-        # TODO misha
+        # Create empty CustomDataset
+        from base.custom_datasets import CustomDataset
+        gen_dataset = CustomDataset(dataset_config)
+
+        # Look for obligate files: .info, graph(s), a dir with labels
+        # info_file = None
+        label_dir = None
+        graph_files = []
+        path = gen_dataset.raw_dir
+        for p in path.iterdir():
+            # if p.is_file() and p.name == '.info':
+            #     info_file = p
+            if p.is_file() and p.name.endswith(f'.{format}'):
+                graph_files.append(p)
+            if p.is_dir() and p.name.endswith('.labels'):
+                label_dir = p
+        # if info_file is None:
+        #     raise RuntimeError(f"No .info file was found at {path}")
+        if len(graph_files) == 0:
+            raise RuntimeError(f"No files with extension '.{format}' found at {path}")
+        if label_dir is None:
+            raise RuntimeError(f"No directory with suffix '.labels' found at {path}")
+
+        # The order of graph files matters: it must be consistent with .info, so we sort them
+        graph_files = sorted(graph_files)
+
+        # Create a temporary dir to store converted data
+        with tmp_dir(path) as tmp:
+            # Convert the data if necessary, write it to an empty directory
+            if format != 'ij':
+                from base.dataset_converter import DatasetConverter
+                DatasetConverter.format_to_ij(gen_dataset.info, graph_files, format, tmp,
+                                              default_node_attr_value, default_edge_attr_value)
+
+            # Move original contents to the temporary dir
+            merge_directories(path, tmp, True)
+
+            # Rename the newly created dir to the original one
+            tmp.rename(path)
+
+        # Check that data is valid
+        gen_dataset.check_validity()
+
+        return gen_dataset
 
     @staticmethod
     def _register_torch_geometric(
@@ -931,7 +1008,7 @@ def _register_torch_geometric(
         will be overwritten.
         :param copy_data: if True processed data will be copied, otherwise a symbolic link
             is created.
-        :return: dataset_config
+        :return: GeneralDataset
         """
         info = DatasetInfo.induce(dataset)
         if name is None:
@@ -948,7 +1025,6 @@ def _register_torch_geometric(
         )
 
         # Check if exists
-        import shutil
         root_dir, files_paths = Declare.dataset_root_dir(dataset_config)
         if root_dir.exists():
             if exists_ok:
@@ -963,16 +1039,6 @@ def _register_torch_geometric(
 
         # Link or copy original contents to our path
         results_dir = gen_dataset.results_dir
-        # if results_dir.exists():
-        #     if not exists_ok:
-        #         raise FileExistsError(f"Graph with config {dataset_config} already exists!")
-        #     else:
-        #         # Clear directory to avoid copying files to a directory linking to those files
-        #         if results_dir.is_symlink():
-        #             os.unlink(results_dir)
-        #         else:
-        #             shutil.rmtree(results_dir)
-
         results_dir.parent.mkdir(parents=True, exist_ok=True)
         if copy_data:
             shutil.copytree(os.path.abspath(dataset.processed_dir), results_dir,
@@ -988,6 +1054,32 @@ def _register_torch_geometric(
         return gen_dataset
 
 
+def merge_directories(source_dir, destination_dir, remove_source=False):
+    """
+    Merge source directory into destination directory, replacing existing files. 
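+    Files are moved rather than copied, so source_dir is left with empty
+    directories only, unless remove_source is True.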
+ + :param source_dir: Path to the source directory to be merged + :param destination_dir: Path to the destination directory + :param remove_source: if True, remove source directory (empty folders) + """ + for root, _, files in os.walk(source_dir): + # Calculate relative path + relative_path = os.path.relpath(root, source_dir) + + # Create destination path + dest_path = os.path.join(destination_dir, relative_path) + os.makedirs(dest_path, exist_ok=True) + + # Move files + for file in files: + src_file = os.path.join(root, file) + dest_file = os.path.join(dest_path, file) + shutil.move(src_file, dest_file) + + if remove_source: + shutil.rmtree(source_dir) + + def is_in_torch_geometric_datasets( full_name: tuple = None ) -> bool: diff --git a/tests/datasets_test.py b/tests/datasets_test.py index 3833377..5bc16ad 100644 --- a/tests/datasets_test.py +++ b/tests/datasets_test.py @@ -1,4 +1,3 @@ -import collections import collections.abc collections.Callable = collections.abc.Callable @@ -10,15 +9,16 @@ from torch import tensor from torch_geometric.data import InMemoryDataset, Data, Dataset -# Monkey path GRAPHS_DIR - before other imports +# Monkey patch GRAPHS_DIR - before other imports from aux import utils - if not str(utils.GRAPHS_DIR).endswith("__DatasetsTest_tmp"): tmp_dir = utils.GRAPHS_DIR.parent / (utils.GRAPHS_DIR.name + "__DatasetsTest_tmp") utils.GRAPHS_DIR = tmp_dir else: tmp_dir = utils.GRAPHS_DIR +from base.dataset_converter import networkx_to_ptg + def my_ctrlc_handler(signal, frame): print('my_ctrlc_handler', tmp_dir, tmp_dir.exists()) @@ -323,6 +323,100 @@ def test_custom_ij_multi(self): self.assertTrue(gen_dataset.num_classes, 2) self.assertTrue(gen_dataset.num_node_features, 3) + def test_custom_other_single(self): + """ """ + from aux.configs import DatasetVarConfig + from aux.configs import DatasetConfig + from aux.declaration import Declare + from base.custom_datasets import CustomDataset + from base.dataset_converter import DatasetConverter + from base.datasets_processing import DatasetManager + import json + import networkx as nx + + g = nx.Graph() + g.add_node(0, a=0.4, b=100) + g.add_node(1, a=0.4, b=100) + g.add_node(2, a=0.3, b=50) + g.add_node(3, a=0.3, b=200) + g.add_node(4, a=0.2, b=75) + g.add_node(5, a=0.4, b=25) + g.add_node(6, a=0.2, b=150) + g.add_node(7, a=0.5, b=80) + g.add_node(8, a=0.1, b=40) + g.add_edge(0, 1, weight=5, type='big') + g.add_edge(1, 2, weight=5, type='big') + g.add_edge(1, 3, weight=3, type='medium') + g.add_edge(1, 4, weight=4, type='small') + g.add_edge(2, 5, weight=2, type='big') + g.add_edge(2, 6, weight=6, type='big') + g.add_edge(3, 4, weight=3, type='medium') + g.add_edge(3, 7, weight=5, type='small') + g.add_edge(4, 8, weight=4, type='big') + g.add_edge(5, 6, weight=1, type='small') + g.add_edge(6, 7, weight=5, type='small') + g.add_edge(7, 8, weight=3, type='medium') + + node_labels = {0: 0, 1: 0, 2: 0, 3: 1, 4: 0, 5: 1, 6: 1, 7: 0, 8: 0} + + true_ptg_data = networkx_to_ptg(g) + + # for format in ['.g6']: + for format in DatasetConverter.supported_formats: + print(f"Checking format {format}") + + name = f'test_{format}' + dc = DatasetConfig('single-graph', 'custom', name) + + # Write graph and attributes files + root, files_paths = Declare.dataset_root_dir(dc) + raw = root / 'raw' + raw.mkdir(parents=True) + DatasetConverter.networkx_to_format(g, format, raw, name=name, + default_node_attr_value={'a': -1, 'b': -1}, + default_edge_attr_value={'weight': -1, 'type': -1}) + + # Write info and labels + with open(raw / '.info', 'w') as 
f: + json.dump({ + "name": name, + "count": 1, + "directed": False, + "nodes": [g.number_of_nodes()], + "remap": False, + "node_attributes": { + "names": ["a", "b"], + "types": ["continuous", "continuous"], + "values": [[0, 1], [0, 200]] + }, + "edge_attributes": { + "names": ["weight", "type"], + "types": ["continuous", "categorical"], + "values": [[1, 6], ['small', 'medium', 'big']] + }, + "labelings": {"binary": 2} + }, f) + + (raw / f'{name}.labels').mkdir() + with open(raw / f'{name}.labels' / 'binary', 'w') as f: + json.dump(node_labels, f) + + # Convert from the format + gen_dataset = DatasetManager.register_custom(dc, format) + + dataset_var_config = DatasetVarConfig( + features={'attr': {'a': 'as_is', 'b': 'as_is'}}, labeling='binary', dataset_ver_ind=0) + gen_dataset.build(dataset_var_config) + ptg_data = gen_dataset.data + + # Check features and edges coincide + self.assertTrue(torch.equal(true_ptg_data.x.sort(dim=0)[0], ptg_data.x.sort(dim=0)[0])) + sorted_edges1 = torch.sort(true_ptg_data.edge_index, dim=1)[0] + sorted_edges2 = torch.sort(ptg_data.edge_index, dim=1)[0] + self.assertTrue(torch.equal(sorted_edges1, sorted_edges2)) + # FIXME add it later when edge features are ready + # self.assertTrue(torch.equal(true_ptg_data.edge_attr, ptg_data.edge_attr)) + def test_ptg_lib(self): """ NOTE: takes a lot of time """ diff --git a/tests/explainers_full_test.py b/tests/explainers_full_test.py index dc3f905..6a4cf37 100644 --- a/tests/explainers_full_test.py +++ b/tests/explainers_full_test.py @@ -8,7 +8,11 @@ import signal from time import time +# Monkey patch EXPLANATIONS_DIR - before other imports from aux import utils +tmp_dir = utils.EXPLANATIONS_DIR / (utils.EXPLANATIONS_DIR.name + str(time())) +utils.EXPLANATIONS_DIR = tmp_dir + from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \ EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH from base.datasets_processing import DatasetManager @@ -29,10 +33,6 @@ # from src.models_builder.models_zoo import model_configs_zoo -tmp_dir = utils.EXPLANATIONS_DIR / (utils.EXPLANATIONS_DIR.name + str(time())) -utils.EXPLANATIONS_DIR = tmp_dir - - def my_ctrlc_handler(signal, frame): if tmp_dir.exists(): shutil.rmtree(tmp_dir) diff --git a/tests/models_test.py b/tests/models_test.py index c4e117e..b890f57 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -6,16 +6,16 @@ import signal from time import time +# Monkey patch MODELS_DIR - before other imports from aux import utils +tmp_dir = utils.MODELS_DIR / (utils.MODELS_DIR.name + str(time())) +utils.MODELS_DIR = tmp_dir from base.datasets_processing import DatasetManager from models_builder.gnn_models import FrameworkGNNModelManager, ProtGNNModelManager, Metric from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern from models_builder.models_zoo import model_configs_zoo -tmp_dir = utils.MODELS_DIR / (utils.MODELS_DIR.name + str(time())) -utils.MODELS_DIR = tmp_dir - def my_ctrlc_handler(signal, frame): if tmp_dir.exists(): From 9cdb678fa569045edf435688161bb59380dcdf74 Mon Sep 17 00:00:00 2001 From: Misha D Date: Tue, 3 Dec 2024 16:20:40 +0300 Subject: [PATCH 2/2] fix style, comments, unused code --- src/aux/utils.py | 16 ++++- src/base/dataset_converter.py | 120 +++++++++++++------------------- src/base/datasets_processing.py | 6 +- 3 files changed, 68 insertions(+), 74 deletions(-) diff --git a/src/aux/utils.py b/src/aux/utils.py index 6084701..2c53ad9 100644 --- a/src/aux/utils.py +++ 
b/src/aux/utils.py @@ -167,16 +167,26 @@ class tmp_dir(): """ Temporary create a directory near the given path. Remove it on exit. """ - def __init__(self, path: Path): + def __init__( + self, + path: Path + ): self.path = path from time import time self.tmp_dir = self.path.parent / (self.path.name + str(time())) - def __enter__(self): + def __enter__( + self + ) -> Path: self.tmp_dir.mkdir(parents=True) return self.tmp_dir - def __exit__(self, exception_type, exception_value, exception_traceback): + def __exit__( + self, + exception_type, + exception_value, + exception_traceback + ) -> None: import shutil try: shutil.rmtree(self.tmp_dir) diff --git a/src/base/dataset_converter.py b/src/base/dataset_converter.py index 9a71a5a..fb30f17 100644 --- a/src/base/dataset_converter.py +++ b/src/base/dataset_converter.py @@ -1,10 +1,12 @@ import json from pathlib import Path -from torch_geometric.utils import to_networkx, from_networkx +from typing import Union + import networkx as nx +from torch_geometric.data import Data +from torch_geometric.utils import from_networkx from base.datasets_processing import DatasetInfo -from base.ptg_datasets import is_graph_directed class DatasetConverter: @@ -86,7 +88,8 @@ def format_to_ij( with open(edge_attr_dir / attr, 'w') as f: json.dump([a[attr] for a in all_edge_attributes], f) - # FIXME add new attri to DatasetInfo? + # We do not add the extracted attributes to DatasetInfo, since it regulates whether + # attributes should be used. @staticmethod def networkx_to_format( @@ -124,6 +127,15 @@ def networkx_to_format( elif format == "s6": nx.write_sparse6(graph, graph_file) + # FORMATS THAT ARE NOT SUPPORTED: + # gexf, multiline_adjlist, weighted_edgelist + # # GRAPHML DOESN'T WORK WITH from_networkx() + # elif data_format == "graphml": + # # LEDA format is not supported as it stores edge attributes as strings + # elif data_format == "leda": + # # PAJEK format is not supported as it stores node attributes as strings + # elif data_format == "pajek": # Only works with graphs that have node labels + else: raise NotImplementedError @@ -172,22 +184,43 @@ def extract_attributes( return node_attributes, edge_attributes -def ptg_to_networkx(ptg_graph): - node_attrs = None - if ptg_graph.x != None: - node_attrs = ["x"] - edge_attrs = None - if ptg_graph.edge_attr != None: - edge_attrs = ["edge_attr"] - - # "graph_attrs" parameter is not working in this torch_geometric version - nx_graph = to_networkx(ptg_graph, node_attrs=node_attrs, edge_attrs=edge_attrs, - to_undirected=not is_graph_directed(ptg_graph)) - - return nx_graph +def read_nx_graph( + data_format: str, + path: Union[Path, str], + **kwargs +) -> nx.Graph: + # FORMATS THAT ARE NOT SUPPORTED: + # gexf, multiline_adjlist, weighted_edgelist + if data_format == "adjlist": # This format does not store graph or node attributes. 
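+        # Extra kwargs are forwarded to networkx, e.g. create_using=nx.DiGraph
+        # (as DatasetConverter.format_to_ij passes for directed datasets).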
+        return nx.read_adjlist(path, **kwargs)
+    elif data_format == "edgelist":
+        return nx.read_edgelist(path, **kwargs)
+    elif data_format == "gml":  # Only works with graphs that have node, edge attributes
+        return nx.read_gml(path)
+    # # GRAPHML DOESN'T WORK WITH from_networkx()
+    # elif data_format == "graphml":
+    #     return nx.read_graphml(path, **kwargs)
+    # # LEDA format is not supported as it stores edge attributes as strings
+    # elif data_format == "leda":
+    #     return nx.read_leda(path, **kwargs)
+    elif data_format == "g6":
+        return nx.read_graph6(path)
+    elif data_format == "s6":
+        return nx.read_sparse6(path)
+    # # PAJEK format is not supported as it stores node attributes as strings
+    # elif data_format == "pajek":  # Only works with graphs that have node labels
+    #     return nx.read_pajek(path, **kwargs)
+    else:
+        raise RuntimeError("the READING format is NOT SUPPORTED!!!")
 
 
-def networkx_to_ptg(nx_graph):
+def networkx_to_ptg(
+    nx_graph: nx.Graph
+) -> Data:
+    """
+    Convert a networkx graph to a PTG Data object.
+    Node and edge attributes that are numeric are concatenated.
+    """
     node_attribute_names = set()
     edge_attribute_names = set()
 
@@ -224,59 +257,6 @@ def networkx_to_ptg(
     return ptg_graph
 
 
-def read_nx_graph(data_format, path, **kwargs):
-    # FORMATS THAT ARE NOT SUPPORTED:
-    # gexf, multiline_adjlist, weighted_edgelist
-    if data_format == "adjlist":  # This format does not store graph or node attributes.
-        # Extra kwargs are forwarded to networkx, e.g. create_using=nx.DiGraph
-        # (as DatasetConverter.format_to_ij passes for directed datasets).
-        return nx.read_adjlist(path, **kwargs)
-    elif data_format == "edgelist":
-        return nx.read_edgelist(path, **kwargs)
-    elif data_format == "gml":  # Only works with graphs that have node, edge attributes
-        return nx.read_gml(path)
-    # # GRAPHML DOESN'T WORK WITH from_networkx()
-    # elif data_format == "graphml":
-    #     return nx.read_graphml(path, **kwargs)
-    # # LEDA format is not supported as it stores edge attributes as strings
-    # elif data_format == "leda":
-    #     return nx.read_leda(path, **kwargs)
-    elif data_format == "g6":
-        return nx.read_graph6(path)
-    elif data_format == "s6":
-        return nx.read_sparse6(path)
-    # # PAJEK format is not supported as it stores node attributes as strings
-    # elif data_format == "pajek":  # Only works with graphs that have node labels
-    #     return nx.read_pajek(path, **kwargs)
-    else:
-        raise RuntimeError("the READING format is NOT SUPPORTED!!!")
-
-
-def write_nx_graph(graph, data_format, path):
-    if data_format == "adjlist":
-        return nx.write_adjlist(graph, path)
-    # elif data_format == "multiline_adjlist":
-    #     return nx.write_multiline_adjlist(graph, path)
-    elif data_format == "edgelist":
-        return nx.write_edgelist(graph, path)
-    # elif data_format == "weighted_edgelist":
-    #     return nx.write_weighted_edgelist(graph, path)
-    # elif data_format == "gexf":
-    #     return nx.write_gexf(graph, path)
-    elif data_format == "gml":
-        return nx.write_gml(graph, path)
-    # elif data_format == "graphml":
-    #     return nx.write_graphml(graph, path)
-    # elif data_format == "leda":
-    #     return nx.write_leda(graph, path)
-    elif data_format == "g6":
-        return nx.write_graph6(graph, path)
-    elif data_format == "s6":
-        return nx.write_sparse6(graph, path)
-    # elif data_format == "pajek":
-    #     return nx.write_pajek(graph, path)
-    else:
-        raise RuntimeError("the WRITING format is NOT SUPPORTED!!!")
-
-
 def example_single():
     g = nx.Graph()
     g.add_node(11, a=0.4, b=100)
diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py
index d0c9e11..3e9819f 100644
--- a/src/base/datasets_processing.py
+++ b/src/base/datasets_processing.py
@@ -1054,7 +1054,11 @@ def 
_register_torch_geometric( return gen_dataset -def merge_directories(source_dir, destination_dir, remove_source=False): +def merge_directories( + source_dir: Union[Path, str], + destination_dir: Union[Path, str], + remove_source: bool = False +): """ Merge source directory into destination directory, replacing existing files.