Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Aug 19, 2024
2 parents 06d97d5 + 27616d9 commit 2707e33
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 20 deletions.
32 changes: 30 additions & 2 deletions src/uni2ts/data/builder/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,39 @@

# TODO: Add __repr__
class DatasetBuilder(abc.ABC):
"""
Base class for DatasetBuilders.
"""

@abc.abstractmethod
def build_dataset(self, *args, **kwargs): ...
def build_dataset(self, *args, **kwargs):
"""
Builds the dataset into the required file format.
"""
...

@abc.abstractmethod
def load_dataset(
self, transform_map: dict[Any, Callable[..., Transformation]]
) -> Dataset: ...
) -> Dataset:
"""
Load the dataset.
:param transform_map: a map which returns the required dataset transformations to be applied
:return: the dataset ready for training
"""
...


class ConcatDatasetBuilder(DatasetBuilder):
"""
Concatenates DatasetBuilders such that they can be loaded together.
"""

def __init__(self, *builders: DatasetBuilder):
"""
:param builders: DatasetBuilders to be concatenated together.
"""
super().__init__()
assert len(builders) > 0, "Must provide at least one builder to ConcatBuilder"
assert all(
Expand All @@ -49,6 +71,12 @@ def build_dataset(self):
def load_dataset(
self, transform_map: dict[Any, Callable[..., Transformation]]
) -> ConcatDataset:
"""
Loads all builders with ConcatDataset.
:param transform_map: a map which returns the required dataset transformations to be applied
:return: the dataset ready for training
"""
return ConcatDataset(
[builder.load_dataset(transform_map) for builder in self.builders]
)
27 changes: 27 additions & 0 deletions src/uni2ts/data/builder/lotsa_v1/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@

@abstract_class_property("dataset_list", "dataset_type_map", "dataset_load_func_map")
class LOTSADatasetBuilder(DatasetBuilder, abc.ABC):
"""
Base class for LOTSA dataset builders.
LOTSA datasets are backed by Hugging Face datasets, and use the HuggingFaceDatasetIndexer for fast indexing.
:attribute dataset_list: list of dataset names belonging to the DatasetBuilder class
:attribute dataset_type_map: map dataset names to TimeSeriesDataset
:attribute dataset_load_func_map: map dataset names to transform_map
:attribute uniform: whether all datasets in the dataset_list have uniform series length
"""

dataset_list: list[str] = NotImplemented
dataset_type_map: dict[str, type[TimeSeriesDataset]] = NotImplemented
dataset_load_func_map: dict[str, Callable[..., TimeSeriesDataset]] = NotImplemented
Expand All @@ -43,6 +53,12 @@ def __init__(
sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE,
storage_path: Path = env.LOTSA_V1_PATH,
):
"""
:param datasets: list of datasets to load
:param weight_map: map dataset names to dataset_weight argument for datasets
:param sample_time_series: how to sample time series from the datasets
:param storage_path: directory to which data is stored
"""
assert all(
dataset in self.dataset_list for dataset in datasets
), f"Invalid datasets {set(datasets).difference(self.dataset_list)}, must be one of {self.dataset_list}"
Expand All @@ -55,6 +71,9 @@ def __init__(
def load_dataset(
self, transform_map: dict[str | type, Callable[..., Transformation]]
) -> Dataset:
"""
Loads all datasets in dataset_list
"""
datasets = [
self.dataset_load_func_map[dataset](
HuggingFaceDatasetIndexer(
Expand All @@ -73,6 +92,14 @@ def _get_transform(
transform_map: dict[str | type, Callable[..., Transformation]],
dataset: str,
) -> Transformation:
"""
Retrieves the Transformation for a given dataset from the transform_map, with the following priority:
1. dataset name
2. dataset type
3. falls back to the default transform if a defaultdict is provided
4. falls back to a transform named `default` in the map
5. falls back to identity transform
"""
if dataset in transform_map:
transform = transform_map[dataset]
elif (dataset_type := self.dataset_type_map[dataset]) in transform_map:
Expand Down
52 changes: 52 additions & 0 deletions src/uni2ts/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@


class SampleTimeSeriesType(Enum):
"""
How to sample from the dataset.
- none: do not sample, return the current index.
- uniform: each time series sampled with equal probability
- proportional: each time series sampled with probability proportional to it's length
"""

NONE = "none"
UNIFORM = "uniform"
PROPORTIONAL = "proportional"
Expand All @@ -47,6 +54,12 @@ def __init__(
sample_time_series: SampleTimeSeriesType = SampleTimeSeriesType.NONE,
dataset_weight: float = 1.0,
):
"""
:param indexer: Underlying Indexer object
:param transform: Transformation to apply to time series
:param sample_time_series: defines how a time series is obtained from the dataset
:param dataset_weight: multiplicative factor to apply to dataset size
"""
self.indexer = indexer
self.transform = transform
self.sample_time_series = sample_time_series
Expand All @@ -62,6 +75,11 @@ def __init__(
raise ValueError(f"Unknown sample type {sample_time_series}")

def __getitem__(self, idx: int) -> dict[str, FlattenedData]:
"""
Obtain a time series from the dataset, flatten
:param idx: index of time series to retrieve. if sample_time_series is specified, this will be ignored.
:return: transformed time series data
"""
if idx < 0 or idx >= len(self):
raise IndexError(
f"Index {idx} out of range for dataset of length {len(self)}"
Expand All @@ -74,16 +92,28 @@ def __getitem__(self, idx: int) -> dict[str, FlattenedData]:

@property
def num_ts(self) -> int:
"""
Get the number of time series in the dataset
"""
return len(self.indexer)

def __len__(self) -> int:
"""
Length is the number of time series multiplied by dataset_weight
"""
return int(np.ceil(self.num_ts * self.dataset_weight))

def _get_data(self, idx: int) -> dict[str, Data | BatchedData]:
"""
Obtains time series from Indexer object
"""
return self.indexer[idx % self.num_ts]

@staticmethod
def _flatten_data(data: dict[str, Data]) -> dict[str, FlattenedData]:
"""
Convert time series type data into a list of univariate time series
"""
return {
k: (
[v]
Expand All @@ -95,6 +125,11 @@ def _flatten_data(data: dict[str, Data]) -> dict[str, FlattenedData]:


class MultiSampleTimeSeriesDataset(TimeSeriesDataset):
"""
Samples multiple time series and stacks them into a single time series.
Underlying dataset should have aligned time series, meaning same start and end dates.
"""

def __init__(
self,
indexer: Indexer[dict[str, Any]],
Expand All @@ -105,6 +140,15 @@ def __init__(
dataset_weight: float = 1.0,
sampler: Sampler = get_sampler("beta_binomial", a=2, b=5),
):
"""
:param indexer: Underlying Indexer object
:param transform: Transformation to apply to time series
:param max_ts: maximum number of time series that can be stacked together
:param combine_fields: fields which should be stacked
:param sample_time_series: defines how a time series is obtained from the dataset
:param dataset_weight: multiplicative factor to apply to dataset size
:param sampler: how to sample the other time series
"""
super().__init__(indexer, transform, sample_time_series, dataset_weight)
self.max_ts = max_ts
self.combine_fields = combine_fields
Expand Down Expand Up @@ -139,12 +183,20 @@ def _flatten_data(


class EvalDataset(TimeSeriesDataset):
"""
Dataset class for validation.
Should be used in conjunction with Eval transformations.
"""

def __init__(
self,
windows: int,
indexer: Indexer[dict[str, Any]],
transform: Transformation,
):
"""
:param windows: number of windows to perform evaluation on
"""
super().__init__(
indexer,
transform,
Expand Down
34 changes: 34 additions & 0 deletions src/uni2ts/data/indexer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,27 @@


class Indexer(abc.ABC, Sequence):
"""
Base class for all Indexers.
An Indexer is responsible for extracting data from an underlying file format.
"""

def __init__(self, uniform: bool = False):
"""
:param uniform: whether the underlying data has uniform length
"""
self.uniform = uniform

def check_index(self, idx: int | slice | Iterable[int]):
"""
Check the validity of a given index.
:param idx: index to check
:return: None
:raises IndexError: if idx is out of bounds
:raises NotImplementedError: if idx is not a valid type
"""
if isinstance(idx, int):
if idx < 0 or idx >= len(self):
raise IndexError(f"Index {idx} out of bounds for length {len(self)}")
Expand All @@ -48,6 +65,12 @@ def check_index(self, idx: int | slice | Iterable[int]):
def __getitem__(
self, idx: int | slice | Iterable[int]
) -> dict[str, Data | BatchedData]:
"""
Retrive the data from the underlying storage in dictionary format.
:param idx: index to retrieve
:return: underlying data with given index
"""
self.check_index(idx)

if isinstance(idx, int):
Expand All @@ -72,9 +95,20 @@ def _getitem_int(self, idx: int) -> dict[str, Data]: ...
def _getitem_iterable(self, idx: Iterable[int]) -> dict[str, BatchedData]: ...

def get_uniform_probabilities(self) -> np.ndarray:
"""
Obtains uniform probability distribution over all time series.
:return: uniform probability distribution
"""
return np.ones(len(self)) / len(self)

def get_proportional_probabilities(self, field: str = "target") -> np.ndarray:
"""
Obtain proportion of each time series based on number of time steps.
:param field: field name to measure time series length
:return: proportional probabilities
"""
if self.uniform:
return self.get_uniform_probabilities()

Expand Down
16 changes: 16 additions & 0 deletions src/uni2ts/data/indexer/hf_dataset_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@


class HuggingFaceDatasetIndexer(Indexer):
"""
Indexer for Hugging Face Datasets
"""

def __init__(self, dataset: Dataset, uniform: bool = False):
"""
:param dataset: underlying Hugging Face Dataset
:param uniform: whether the underlying data has uniform length
"""
super().__init__(uniform=uniform)
self.dataset = dataset
self.features = dict(self.dataset.features)
Expand Down Expand Up @@ -109,6 +117,14 @@ def _pa_column_to_numpy(
return array

def get_proportional_probabilities(self, field: str = "target") -> np.ndarray:
"""
Obtain proportion of each time series based on number of time steps.
Leverages pyarrow.compute for fast implementation.
:param field: field name to measure time series length
:return: proportional probabilities
"""

if self.uniform:
return self.get_uniform_probabilities()

Expand Down
Loading

0 comments on commit 2707e33

Please sign in to comment.