Skip to content

Commit

Permalink
Manually set 'freq' for specific datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Aug 19, 2024
1 parent 27616d9 commit acfbdce
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions src/uni2ts/data/builder/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import argparse
from collections import defaultdict
from dataclasses import dataclass
from itertools import product
from pathlib import Path
Expand All @@ -26,15 +27,24 @@

from uni2ts.common.env import env
from uni2ts.common.typing import GenFunc
from uni2ts.data.builder._base import DatasetBuilder
from uni2ts.data.dataset import EvalDataset, SampleTimeSeriesType, TimeSeriesDataset
from uni2ts.data.indexer import HuggingFaceDatasetIndexer
from uni2ts.transform import Transformation

from ._base import DatasetBuilder
# Manually set the freq of the datasets whose freq cannot be inferred
# automatically, i.e. the fallback used when pd.infer_freq returns None.
# Keys are dataset names and values are frequency strings; any dataset not
# listed here falls back to the default freq "H".
freq_dict = defaultdict(
lambda: "H",  # default freq for dataset names without an explicit entry
{
"weather": "10T",
"weather_eval": "10T",
},
)


def _from_long_dataframe(
df: pd.DataFrame,
dataset: str,
offset: Optional[int] = None,
date_offset: Optional[pd.Timestamp] = None,
) -> tuple[GenFunc, Features]:
Expand All @@ -50,7 +60,11 @@ def example_gen_func() -> Generator[dict[str, Any], None, None]:
yield {
"target": item_df.to_numpy(),
"start": item_df.index[0],
"freq": pd.infer_freq(item_df.index),
"freq": (
pd.infer_freq(df.index)
if pd.infer_freq(df.index) is not None
else freq_dict[dataset]
),
"item_id": item_id,
}

Expand All @@ -68,6 +82,7 @@ def example_gen_func() -> Generator[dict[str, Any], None, None]:

def _from_wide_dataframe(
df: pd.DataFrame,
dataset: str,
offset: Optional[int] = None,
date_offset: Optional[pd.Timestamp] = None,
) -> tuple[GenFunc, Features]:
Expand All @@ -83,7 +98,11 @@ def example_gen_func() -> Generator[dict[str, Any], None, None]:
yield {
"target": df.iloc[:, i].to_numpy(),
"start": df.index[0],
"freq": pd.infer_freq(df.index),
"freq": (
pd.infer_freq(df.index)
if pd.infer_freq(df.index) is not None
else freq_dict[dataset]
),
"item_id": f"item_{i}",
}

Expand All @@ -101,6 +120,7 @@ def example_gen_func() -> Generator[dict[str, Any], None, None]:

def _from_wide_dataframe_multivariate(
df: pd.DataFrame,
dataset: str,
offset: Optional[int] = None,
date_offset: Optional[pd.Timestamp] = None,
) -> tuple[GenFunc, Features]:
Expand All @@ -113,7 +133,11 @@ def example_gen_func() -> Generator[dict[str, Any], None, None]:
yield {
"target": df.to_numpy().T,
"start": df.index[0],
"freq": pd.infer_freq(df.index),
"freq": (
pd.infer_freq(df.index)
if pd.infer_freq(df.index) is not None
else freq_dict[dataset]
),
"item_id": "item_0",
}

Expand Down Expand Up @@ -166,7 +190,7 @@ def build_dataset(
)

example_gen_func, features = _from_dataframe(
df, offset=offset, date_offset=date_offset
df, dataset=self.dataset, offset=offset, date_offset=date_offset
)
hf_dataset = datasets.Dataset.from_generator(
example_gen_func, features=features
Expand Down Expand Up @@ -218,7 +242,7 @@ def build_dataset(self, file: Path, dataset_type: str):
" Valid options are 'long', 'wide', and 'wide_multivariate'."
)

example_gen_func, features = _from_dataframe(df)
example_gen_func, features = _from_dataframe(df, dataset=self.dataset)
hf_dataset = datasets.Dataset.from_generator(
example_gen_func, features=features
)
Expand Down

0 comments on commit acfbdce

Please sign in to comment.