Skip to content

Commit

Permalink
Black & Isort
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Oct 22, 2024
1 parent 326f14d commit 7f43cc1
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 78 deletions.
2 changes: 1 addition & 1 deletion project/multi_scale/finetune/small/ettm1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ model=moirai_1.1_R_small
cp=conf/multi_scale/finetune
exp_name=lsf
cl=3000
ft_pattern=full
ft_pattern=full_0

data=ettm1
ps=128
Expand Down
4 changes: 2 additions & 2 deletions project/multi_scale/finetune/small/ettm2.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#!/bin/bash

export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0;
export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=1;

model=moirai_1.1_R_small
cp=conf/multi_scale/finetune
exp_name=lsf
cl=3000
ft_pattern=full
ft_pattern=full_0

data=ettm2
ps=64
Expand Down
27 changes: 21 additions & 6 deletions src/uni2ts/data/builder/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

from uni2ts.common.env import env
from uni2ts.common.typing import GenFunc

# from ._base import DatasetBuilder
from uni2ts.data.builder._base import DatasetBuilder
from uni2ts.data.dataset import (
EvalDataset,
FinetuneDataset,
Expand All @@ -35,8 +38,6 @@
from uni2ts.data.indexer import HuggingFaceDatasetIndexer
from uni2ts.transform import Transformation

# from ._base import DatasetBuilder
from uni2ts.data.builder._base import DatasetBuilder

def _from_long_dataframe(
df: pd.DataFrame,
Expand Down Expand Up @@ -212,8 +213,15 @@ def build_dataset(
df = pd.read_csv(file, index_col=0, parse_dates=True)

if normalize:
end = offset if offset is not None else len(
df[df.index <= date_offset].index) if date_offset is not None else len(df.index)
end = (
offset
if offset is not None
else (
len(df[df.index <= date_offset].index)
if date_offset is not None
else len(df.index)
)
)
df = self.scale(df, 0, end)

if dataset_type == "long":
Expand Down Expand Up @@ -293,8 +301,15 @@ def build_dataset(
df = pd.read_csv(file, index_col=0, parse_dates=True)

if normalize:
end = offset if offset is not None else len(
df[df.index <= date_offset].index) if date_offset is not None else len(df.index)
end = (
offset
if offset is not None
else (
len(df[df.index <= date_offset].index)
if date_offset is not None
else len(df.index)
)
)
df = self.scale(df, 0, end)

if dataset_type == "long":
Expand Down
14 changes: 12 additions & 2 deletions src/uni2ts/model/lsf_moirai/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,12 @@ def default_train_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down Expand Up @@ -632,7 +637,12 @@ def default_val_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down
22 changes: 18 additions & 4 deletions src/uni2ts/model/lsf_moirai_point/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,12 @@ def __init__(
self.criterion = torch.nn.MSELoss()

def replace_forecast_head(self):
seq_len = math.ceil(self.context_length / self.patch_size) + math.ceil(self.prediction_length / self.patch_size)
self.module.replace_forecast_head(seq_len=seq_len, pred_len=self.prediction_length)
seq_len = math.ceil(self.context_length / self.patch_size) + math.ceil(
self.prediction_length / self.patch_size
)
self.module.replace_forecast_head(
seq_len=seq_len, pred_len=self.prediction_length
)

def forward(
self,
Expand Down Expand Up @@ -469,7 +473,12 @@ def default_train_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down Expand Up @@ -596,7 +605,12 @@ def default_val_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down
8 changes: 6 additions & 2 deletions src/uni2ts/model/lsf_moirai_point/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ def __init__(
self.replace_forecast_head()

def replace_forecast_head(self):
seq_len = math.ceil(self.hparams.context_length / self.hparams.patch_size) + math.ceil(self.hparams.prediction_length / self.hparams.patch_size)
self.module.replace_forecast_head(seq_len=seq_len, pred_len=self.hparams.prediction_length)
seq_len = math.ceil(
self.hparams.context_length / self.hparams.patch_size
) + math.ceil(self.hparams.prediction_length / self.hparams.patch_size)
self.module.replace_forecast_head(
seq_len=seq_len, pred_len=self.hparams.prediction_length
)

def create_predictor(
self,
Expand Down
21 changes: 18 additions & 3 deletions src/uni2ts/model/moirai/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,15 +512,25 @@ def default_train_transform():
min_mask_ratio=self.hparams.min_mask_ratio,
max_mask_ratio=self.hparams.max_mask_ratio,
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
)
if self.context_length is None or self.prediction_length is None
else MaskedPredictionGivenFixedConfig(
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down Expand Up @@ -637,7 +647,12 @@ def default_val_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down
5 changes: 3 additions & 2 deletions src/uni2ts/model/multi_scale_moirai/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
patch_size: Optional[int] = None,
finetune_pattern: str | list[str] = "full",
num_new_scales: int = 1,
ds_factor: int = 2
ds_factor: int = 2,
):
super().__init__()
self.save_hyperparameters(ignore=["module"])
Expand Down Expand Up @@ -509,7 +509,8 @@ def default_train_transform(
)
+ AddSampleIndex(
fields=("target",),
optional_fields=("past_feat_dynamic_real",) + self.new_scales_target_fields,
optional_fields=("past_feat_dynamic_real",)
+ self.new_scales_target_fields,
sample_id_field="sample_id",
expected_ndim=3,
collection_type=dict,
Expand Down
10 changes: 7 additions & 3 deletions src/uni2ts/model/multi_scale_moirai/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
num_samples: int = 100,
pretrained_checkpoint_path: str = None,
num_new_scales: int = 1,
ds_factor: int = 2
ds_factor: int = 2,
):
assert (module is not None) or (
module_kwargs is not None
Expand Down Expand Up @@ -766,7 +766,9 @@ def _convert(
# Downsample
past_target = self._downsample(past_target, left=True)
past_observed_target = self._downsample(past_observed_target, left=True)
past_is_pad = self._downsample(past_is_pad.bool(), ds_factor=self.ds_factor, left=False).int()
past_is_pad = self._downsample(
past_is_pad.bool(), ds_factor=self.ds_factor, left=False
).int()
context_length = math.ceil(context_length / 2)

target.extend(
Expand Down Expand Up @@ -882,7 +884,9 @@ def _convert(
prediction_mask,
)

def _downsample(self, arr: torch.Tensor, ds_factor: int = 2, left: bool = True) -> torch.Tensor:
def _downsample(
self, arr: torch.Tensor, ds_factor: int = 2, left: bool = True
) -> torch.Tensor:
# Check if the input tensor is 2D (bs, time) or 3D (*bs, time, feature)
if arr.ndim == 2:
# 2D case: arr is (bs, time) without feature dimension
Expand Down
6 changes: 4 additions & 2 deletions src/uni2ts/model/multi_scale_moirai/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@

from uni2ts.common.torch_util import mask_fill, packed_attention_mask
from uni2ts.distribution import DistributionOutput
from uni2ts.module.multi_scale.transformer import TransformerEncoder
from uni2ts.module.norm import RMSNorm
from uni2ts.module.packed_scaler import PackedNOPScaler, PackedStdScaler
from uni2ts.module.position import (
BinaryAttentionBias,
QueryKeyProjection,
RotaryProjection,
)
from uni2ts.module.multi_scale.transformer import TransformerEncoder
from uni2ts.module.ts_embed import MultiInSizeLinear


Expand Down Expand Up @@ -124,7 +124,9 @@ def __init__(
activation=F.silu,
use_glu=True,
use_qk_norm=True,
var_attn_bias_layer=partial(BinaryAttentionBias), # ToDo: 这个var attn bias可以改
var_attn_bias_layer=partial(
BinaryAttentionBias
), # ToDo: 这个var attn bias可以改
time_qk_proj_layer=partial(
QueryKeyProjection,
proj_layer=RotaryProjection,
Expand Down
15 changes: 12 additions & 3 deletions src/uni2ts/model/seasonal_naive_moirai/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,12 @@ def default_train_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down Expand Up @@ -651,7 +656,12 @@ def default_val_transform(
+ EvalMaskedPrediction(
mask_length=math.ceil(prediction_length / patch_size),
target_field="target",
truncate_fields=("variate_id", "time_id", "observed_mask", "sample_id"),
truncate_fields=(
"variate_id",
"time_id",
"observed_mask",
"sample_id",
),
optional_truncate_fields=("past_feat_dynamic_real",),
prediction_mask_field="prediction_mask",
expected_ndim=3,
Expand Down Expand Up @@ -713,4 +723,3 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
if name in self.trainable_params
}
return filtered_state

Loading

0 comments on commit 7f43cc1

Please sign in to comment.