Skip to content

Commit

Permalink
make better files in base
Browse files Browse the repository at this point in the history
  • Loading branch information
LukyanovKirillML committed Nov 21, 2024
1 parent 45220f9 commit c73763a
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 94 deletions.
75 changes: 58 additions & 17 deletions src/base/custom_datasets.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import json
import os
from pathlib import Path
from typing import Union

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset

from aux.declaration import Declare
from base.datasets_processing import GeneralDataset, DatasetInfo
from aux.configs import DatasetConfig, DatasetVarConfig
from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern
from base.ptg_datasets import LocalDataset


class CustomDataset(GeneralDataset):
class CustomDataset(
GeneralDataset
):
""" User-defined dataset in 'ij' format.
"""
def __init__(self, dataset_config: DatasetConfig):
def __init__(
self,
dataset_config: Union[ConfigPattern, DatasetConfig]
):
"""
Args:
dataset_config: DatasetConfig dict from frontend
Expand All @@ -27,31 +34,44 @@ def __init__(self, dataset_config: DatasetConfig):
self.edge_index = None

@property
def node_attributes_dir(self):
def node_attributes_dir(
self
):
""" Path to dir with node attributes. """
return self.root_dir / 'raw' / (self.name + '.node_attributes')

@property
def edge_attributes_dir(self):
def edge_attributes_dir(
self
):
""" Path to dir with edge attributes. """
return self.root_dir / 'raw' / (self.name + '.edge_attributes')

@property
def labels_dir(self):
def labels_dir(
self
):
""" Path to dir with labels. """
return self.root_dir / 'raw' / (self.name + '.labels')

@property
def edges_path(self):
def edges_path(
self
):
""" Path to file with edge list. """
return self.root_dir / 'raw' / (self.name + '.ij')

@property
def edge_index_path(self):
def edge_index_path(
self
):
""" Path to dir with labels. """
return self.root_dir / 'raw' / (self.name + '.edge_index')

def build(self, dataset_var_config: DatasetVarConfig):
def build(
self,
dataset_var_config: Union[ConfigPattern, DatasetVarConfig]
) -> None:
""" Build ptg dataset based on dataset_var_config and create DatasetVarData.
"""
if dataset_var_config == self.dataset_var_config:
Expand All @@ -62,7 +82,10 @@ def build(self, dataset_var_config: DatasetVarConfig):
self.dataset_var_config = dataset_var_config
self.dataset = LocalDataset(self.results_dir, process_func=self._create_ptg)

def _compute_stat(self, stat):
def _compute_stat(
self,
stat: str
) -> dict:
""" Compute some additional stats
"""
if stat == "attr_corr":
Expand Down Expand Up @@ -123,7 +146,9 @@ def _compute_stat(self, stat):
else:
return super()._compute_stat(stat)

def _compute_dataset_data(self):
def _compute_dataset_data(
self
) -> None:
""" Get DatasetData for debug graph
Structure according to https://docs.google.com/spreadsheets/d/1fNI3sneeGoOFyIZP_spEjjD-7JX2jNl_P8CQrA4HZiI/edit#gid=1096434224
"""
Expand Down Expand Up @@ -272,7 +297,9 @@ def _compute_dataset_data(self):
# if self.info.name == "":
# self.dataset_data['info']['name'] = '/'.join(self.dataset_config.full_name())

def _create_ptg(self):
def _create_ptg(
self
) -> None:
""" Create PTG Dataset and save tensors
"""
if self.edge_index is None:
Expand All @@ -295,7 +322,10 @@ def _create_ptg(self):
self.results_dir.mkdir(exist_ok=True, parents=True)
torch.save(InMemoryDataset.collate(data_list), self.results_dir / 'data.pt')

def _iter_nodes(self, graph: int = None):
def _iter_nodes(
self,
graph: int = None
) -> None:
""" Iterate over nodes according to mapping. Yields pairs of (node_index, original_id)
"""
# offset = sum(self.info.nodes[:graph]) if self.is_multi() else 0
Expand All @@ -308,7 +338,10 @@ def _iter_nodes(self, graph: int = None):
for n in range(self.info.nodes[graph or 0]):
yield offset+n, str(n)

def _labeling_tensor(self, g_ix=None) -> list:
def _labeling_tensor(
self,
g_ix=None
) -> list:
""" Returns list of labels (not tensors) """
y = []
# Read labels
Expand All @@ -330,21 +363,29 @@ def _labeling_tensor(self, g_ix=None) -> list:

return y

def _feature_tensor(self, g_ix=None) -> list:
def _feature_tensor(
self,
g_ix=None
) -> list:
""" Returns list of features (not tensors) for graph g_ix.
"""
features = self.dataset_var_config.features # dict about attributes construction
nodes_onehot = "str_g" in features and features["str_g"] == "one_hot"

# Read attributes
def one_hot(x, values):
def one_hot(
x: int,
values: list
) -> list:
res = [0] * len(values)
for ix, v in enumerate(values):
if x == v:
res[ix] = 1
return res

def as_is(x):
def as_is(
x
) -> list:
return x if isinstance(x, list) else [x]

# TODO other encoding types from Kirill
Expand Down
Loading

0 comments on commit c73763a

Please sign in to comment.