Skip to content

Commit

Permalink
style fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
mishadr committed Dec 6, 2024
1 parent b70c314 commit 380a752
Show file tree
Hide file tree
Showing 14 changed files with 568 additions and 231 deletions.
2 changes: 1 addition & 1 deletion src/aux/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def set_defaults_config_pattern_info(

def to_json(
self
):
) -> dict:
""" Special method which allows to use json.dumps() on Config object """
return self.to_dict()

Expand Down
12 changes: 6 additions & 6 deletions src/base/custom_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,41 @@ def __init__(
@property
def node_attributes_dir(
self
):
) -> Path:
""" Path to dir with node attributes. """
return self.root_dir / 'raw' / (self.name + '.node_attributes')

@property
def edge_attributes_dir(
self
):
) -> Path:
""" Path to dir with edge attributes. """
return self.root_dir / 'raw' / (self.name + '.edge_attributes')

@property
def labels_dir(
self
):
) -> Path:
""" Path to dir with labels. """
return self.root_dir / 'raw' / (self.name + '.labels')

@property
def edges_path(
self
):
) -> Path:
""" Path to file with edge list. """
return self.root_dir / 'raw' / (self.name + '.ij')

@property
def edge_index_path(
self
):
) -> Path:
""" Path to dir with labels. """
return self.root_dir / 'raw' / (self.name + '.edge_index')

def check_validity(
self
):
) -> None:
""" Check that dataset files (graph and attributes) are valid and consistent with .info.
"""
# Assuming info is OK
Expand Down
14 changes: 7 additions & 7 deletions src/base/dataset_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def set(
self,
stat: str,
value: Union[int, float, dict, str]
):
) -> None:
""" Set statistics to a specified value and save to file.
"""
assert stat in DatasetStats.all_stats
Expand All @@ -129,7 +129,7 @@ def set(
def remove(
self,
stat: str
):
) -> None:
""" Remove statistics from dict and file.
"""
if stat in self.stats:
Expand All @@ -140,15 +140,15 @@ def remove(

def clear_all_stats(
self
):
) -> None:
""" Remove all stats. E.g. the graph has changed.
"""
for s in DatasetStats.all_stats:
self.remove(s)

def update_var_config(
self
):
) -> None:
""" Remove var stats from dict since dataset config has changed.
"""
for s in DatasetStats.var_stats:
Expand All @@ -158,7 +158,7 @@ def update_var_config(
def _compute(
self,
stat: str
):
) -> None:
""" Compute statistics for a single graph.
Result could be: a number, a string, a distribution, a dict of ones.
"""
Expand Down Expand Up @@ -255,7 +255,7 @@ def _compute(
def _compute_multi(
self,
stat: str
):
) -> None:
""" Compute statistics for a multiple-graphs dataset.
Result could be: a number, a string, a distribution, a dict of ones.
"""
Expand Down Expand Up @@ -286,7 +286,7 @@ def _compute_multi(

def list_to_hist(
a_list: list
):
) -> dict:
""" Convert a list of integers/floats to a frequency histogram, return it as a dict
"""
return {k: v for k, v in Counter(a_list).most_common()}
36 changes: 18 additions & 18 deletions src/base/datasets_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import torch_geometric
from torch import default_generator, randperm
from torch_geometric.data import Dataset, InMemoryDataset
from torch_geometric.data import Dataset, InMemoryDataset, Data

from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern
from aux.custom_decorators import timing_decorator
Expand Down Expand Up @@ -123,7 +123,7 @@ def save(
@staticmethod
def induce(
dataset: Dataset
):
) -> object:
""" Induce metainfo from a given PTG dataset.
"""
res = DatasetInfo()
Expand All @@ -144,7 +144,7 @@ def induce(
@staticmethod
def read(
path: Union[str, Path]
):
) -> object:
""" Read info from a file. """
with path.open('r') as f:
a_dict = json.load(f)
Expand Down Expand Up @@ -346,62 +346,62 @@ def __init__(
@property
def root_dir(
self
):
) -> Path:
""" Dataset root directory with folders 'raw' and 'prepared'. """
# FIXME Misha, dataset_prepared_dir return path and files_paths not only path
return Declare.dataset_root_dir(self.dataset_config)[0]

@property
def results_dir(
self
):
) -> Path:
""" Path to 'prepared/../' folder where tensor data is stored. """
# FIXME Misha, dataset_prepared_dir return path and files_paths not only path
return Path(Declare.dataset_prepared_dir(self.dataset_config, self.dataset_var_config)[0])

@property
def raw_dir(
self
):
) -> Path:
""" Path to 'raw/' folder where raw data is stored. """
return self.root_dir / 'raw'

@property
def api_path(
self
):
) -> Path:
""" Path to '.api' file. Could be not present. """
return self.root_dir / '.api'

@property
def info_path(
self
):
) -> Path:
""" Path to '.info' file. """
return self.root_dir / 'raw' / '.info'

@property
def data(
self
):
) -> Data:
return self.dataset._data

@property
def num_classes(
self
):
) -> int:
return self.dataset.num_classes

@property
def num_node_features(
self
):
) -> int:
return self.dataset.num_node_features

@property
def labels(
self
):
) -> torch.Tensor:
if self._labels is None:
# NOTE: this is a copy from torch_geometric.data.dataset v=2.3.1
from torch_geometric.data.dataset import _get_flattened_data_list
Expand All @@ -428,7 +428,7 @@ def is_multi(
def build(
self,
dataset_var_config: Union[ConfigPattern, DatasetVarConfig]
):
) -> None:
""" Create node feature tensors from attributes based on dataset_var_config.
"""
raise NotImplementedError()
Expand Down Expand Up @@ -589,23 +589,23 @@ def _compute_dataset_var_data(
def get_stat(
self,
stat: str
):
) -> Union[int, float, dict, str]:
""" Get statistics.
"""
return self.stats.get(stat)

def _compute_stat(
self,
stat: str
):
) -> None:
""" Compute a non-standard statistics.
"""
# Should bw defined in a subclass
# Should be defined in a subclass
raise NotImplementedError()

def is_one_hot_able(
self
):
) -> bool:
""" Return whether features are 1-hot encodings. If yes, nodes can be colored.
"""
assert self.dataset_var_config
Expand Down Expand Up @@ -942,7 +942,7 @@ def merge_directories(
source_dir: Union[Path, str],
destination_dir: Union[Path, str],
remove_source: bool = False
):
) -> None:
"""
Merge source directory into destination directory, replacing existing files.
Expand Down
4 changes: 2 additions & 2 deletions src/base/ptg_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,12 @@ def __init__(
@property
def processed_file_names(
self
):
) -> str:
return 'data.pt'

def process(
self
):
) -> None:
raise RuntimeError("Dataset is supposed to be processed and saved earlier.")
# torch.save(self.collate(self.data_list), self.processed_paths[0])

Expand Down
4 changes: 2 additions & 2 deletions src/base/vk_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AttrInfo:

@staticmethod
def vk_attr(
):
) -> dict:
vk_dict = {
('age',): list(range(0, len(AGE_GROUPS) + 1)),
('sex',): [1, 2],
Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(

def _compute_dataset_data(
self
):
) -> None:
""" Get DatasetData for VK graph
"""
super()._compute_dataset_data()
Expand Down
Loading

0 comments on commit 380a752

Please sign in to comment.