From 078257df7b5efd87e0673787fd602a99bb460383 Mon Sep 17 00:00:00 2001 From: Qiao Zhongzheng Date: Tue, 15 Oct 2024 16:31:10 +0800 Subject: [PATCH] Change val data distance. Add scripts to new finetune. --- cli/conf/new_eval/default.yaml | 2 +- cli/conf/new_finetune/data/electricity.yaml | 6 +- cli/conf/new_finetune/data/etth1.yaml | 10 +- cli/conf/new_finetune/data/etth2.yaml | 6 +- cli/conf/new_finetune/data/ettm1.yaml | 6 +- cli/conf/new_finetune/data/ettm2.yaml | 6 +- cli/conf/new_finetune/data/weather.yaml | 6 +- cli/conf/new_finetune/default.yaml | 6 +- .../model/moirai_1.1_R_small.yaml | 4 +- .../new_finetune/val_data/electricity.yaml | 20 +- cli/conf/new_finetune/val_data/etth1.yaml | 16 +- cli/conf/new_finetune/val_data/etth2.yaml | 20 +- cli/conf/new_finetune/val_data/ettm1.yaml | 21 +- cli/conf/new_finetune/val_data/ettm2.yaml | 21 +- cli/conf/new_finetune/val_data/weather.yaml | 16 +- cli/train.py | 22 +- .../eval/small/run_multi.sh | 10 +- .../new_finetune/eval/small/electricity.sh | 44 ++++ project/new_finetune/eval/small/etth1.sh | 43 ++++ project/new_finetune/eval/small/etth2.sh | 44 ++++ project/new_finetune/eval/small/ettm1.sh | 44 ++++ project/new_finetune/eval/small/ettm2.sh | 45 ++++ project/new_finetune/eval/small/run_multi.sh | 8 + project/new_finetune/eval/small/weather.sh | 44 ++++ .../new_finetune/train/small/electricity.sh | 32 +++ project/new_finetune/train/small/etth1.sh | 32 +++ project/new_finetune/train/small/etth2.sh | 32 +++ project/new_finetune/train/small/ettm1.sh | 32 +++ project/new_finetune/train/small/ettm2.sh | 32 +++ project/new_finetune/train/small/run_multi.sh | 7 + project/new_finetune/train/small/weather.sh | 32 +++ src/uni2ts/data/builder/simple.py | 103 ++++++++-- src/uni2ts/data/dataset.py | 4 +- src/uni2ts/model/new_moirai/finetune.py | 21 +- src/uni2ts/module/attention.py | 193 +++++++++--------- src/uni2ts/module/position/attn_projection.py | 4 +- src/uni2ts/module/transformer.py | 2 +- src/uni2ts/transform/__init__.py | 6 +- src/uni2ts/transform/crop.py | 5 +- 39 files changed, 773 insertions(+), 234 deletions(-) create mode 100644 project/new_finetune/eval/small/electricity.sh create mode 100644 project/new_finetune/eval/small/etth1.sh create mode 100644 project/new_finetune/eval/small/etth2.sh create mode 100644 project/new_finetune/eval/small/ettm1.sh create mode 100644 project/new_finetune/eval/small/ettm2.sh create mode 100644 project/new_finetune/eval/small/run_multi.sh create mode 100644 project/new_finetune/eval/small/weather.sh create mode 100644 project/new_finetune/train/small/electricity.sh create mode 100644 project/new_finetune/train/small/etth1.sh create mode 100644 project/new_finetune/train/small/etth2.sh create mode 100644 project/new_finetune/train/small/ettm1.sh create mode 100644 project/new_finetune/train/small/ettm2.sh create mode 100644 project/new_finetune/train/small/run_multi.sh create mode 100644 project/new_finetune/train/small/weather.sh diff --git a/cli/conf/new_eval/default.yaml b/cli/conf/new_eval/default.yaml index 441c274..76bceb7 100644 --- a/cli/conf/new_eval/default.yaml +++ b/cli/conf/new_eval/default.yaml @@ -1,6 +1,6 @@ hydra: run: - dir: outputs/eval/${hydra:runtime.choices.model}/${exp_name}/${data.dataset_name}/${data.mode}/cl${model.context_length}_pl${data.prediction_length} + dir: outputs/new_eval/${hydra:runtime.choices.model}/${exp_name}/${data.dataset_name}/${data.mode}/cl${model.context_length}_pl${data.prediction_length} defaults: - model: ??? - data: ??? 
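The headline change, "Change val data distance", comes down to how many training and validation windows each dataset contributes. Below is a minimal, illustrative sketch of the window arithmetic used by generate_finetune_builder and generate_eval_builder further down in this patch; it is not part of the patch itself. The numeric values are the ETTh1 ones from data/etth1.yaml and val_data/etth1.yaml, and the context/prediction lengths follow the train scripts (cl=3000, pl=96 shown here).

# Window counts implied by the builders in src/uni2ts/data/builder/simple.py (ETTh1 values).
context_length, prediction_length = 3000, 96

# Training: distance=1, i.e. every possible window (standard LSF setting).
train_length = 8640
train_windows = train_length - context_length - prediction_length + 1   # 5545

# Validation: a per-dataset stride (13 for ETTh1/ETTh2) instead of distance=1,
# which cuts the number of validation windows by roughly that factor.
eval_length, distance = 2880, 13
val_windows = (eval_length - prediction_length) // distance + 1          # 215

print(train_windows, val_windows)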
diff --git a/cli/conf/new_finetune/data/electricity.yaml b/cli/conf/new_finetune/data/electricity.yaml index 4fb675b..70ca032 100644 --- a/cli/conf/new_finetune/data/electricity.yaml +++ b/cli/conf/new_finetune/data/electricity.yaml @@ -1,2 +1,6 @@ -_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +_target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: electricity +train_length: 18412 +prediction_length: ??? +context_length: ??? +patch_size: ??? diff --git a/cli/conf/new_finetune/data/etth1.yaml b/cli/conf/new_finetune/data/etth1.yaml index 83d5408..f7235c4 100644 --- a/cli/conf/new_finetune/data/etth1.yaml +++ b/cli/conf/new_finetune/data/etth1.yaml @@ -1,10 +1,6 @@ -#_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder -#dataset: ETTh1 -#weight: 1000 - _target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: ETTh1 train_length: 8640 -prediction_length: 96 -context_length: 3000 -patch_size: 64 \ No newline at end of file +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/data/etth2.yaml b/cli/conf/new_finetune/data/etth2.yaml index 13d29d6..1dd47f7 100644 --- a/cli/conf/new_finetune/data/etth2.yaml +++ b/cli/conf/new_finetune/data/etth2.yaml @@ -1,2 +1,6 @@ -_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +_target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: ETTh2 +train_length: 8640 +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/data/ettm1.yaml b/cli/conf/new_finetune/data/ettm1.yaml index df066af..dbde79e 100644 --- a/cli/conf/new_finetune/data/ettm1.yaml +++ b/cli/conf/new_finetune/data/ettm1.yaml @@ -1,2 +1,6 @@ -_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +_target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: ETTm1 +train_length: 34560 +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/data/ettm2.yaml b/cli/conf/new_finetune/data/ettm2.yaml index 5ffbcc5..5c402f1 100644 --- a/cli/conf/new_finetune/data/ettm2.yaml +++ b/cli/conf/new_finetune/data/ettm2.yaml @@ -1,2 +1,6 @@ -_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +_target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: ETTm2 +train_length: 34560 +prediction_length: ??? +context_length: ??? +patch_size: ??? diff --git a/cli/conf/new_finetune/data/weather.yaml b/cli/conf/new_finetune/data/weather.yaml index 41d5b06..a6fa5fd 100644 --- a/cli/conf/new_finetune/data/weather.yaml +++ b/cli/conf/new_finetune/data/weather.yaml @@ -1,2 +1,6 @@ -_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +_target_: uni2ts.data.builder.simple.generate_finetune_builder dataset: weather +train_length: 36887 +prediction_length: ??? +context_length: ??? +patch_size: ??? diff --git a/cli/conf/new_finetune/default.yaml b/cli/conf/new_finetune/default.yaml index 6ad6292..cd9f88e 100644 --- a/cli/conf/new_finetune/default.yaml +++ b/cli/conf/new_finetune/default.yaml @@ -36,7 +36,7 @@ trainer: - _target_: uni2ts.callbacks.earlystop.WarmupEarlyStopping # lightning.pytorch.callbacks.EarlyStopping monitor: val/PackedNLLLoss min_delta: 0.0 - patience: 30 + patience: 3 # Set to a small value as now each epoch has many batches. 
mode: min strict: false verbose: true @@ -48,9 +48,9 @@ trainer: gradient_clip_algorithm: norm train_dataloader: _target_: uni2ts.data.loader.DataLoader - batch_size: 64 + batch_size: 512 # Can use a large batch size after disabling sequence packing. batch_size_factor: 2.0 - cycle: false # true + cycle: false # Set it as false to loop over all batches per epoch num_batches_per_epoch: null shuffle: true num_workers: 11 diff --git a/cli/conf/new_finetune/model/moirai_1.1_R_small.yaml b/cli/conf/new_finetune/model/moirai_1.1_R_small.yaml index 25f4e52..4c605d3 100644 --- a/cli/conf/new_finetune/model/moirai_1.1_R_small.yaml +++ b/cli/conf/new_finetune/model/moirai_1.1_R_small.yaml @@ -29,11 +29,11 @@ val_metric: - _target_: uni2ts.loss.packed.PackedMSELoss - _target_: uni2ts.loss.packed.PackedNRMSELoss normalize: absolute_target_squared -lr: 1e-4 +lr: 5e-7 # On ETT dataset, using 1e-6/5e-7 converge within 1-2 epochs. 1e-7 converge in tens of epochs weight_decay: 1e-1 beta1: 0.9 beta2: 0.98 -num_training_steps: null # ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_training_steps: null num_warmup_steps: 0 patch_size: null context_length: null diff --git a/cli/conf/new_finetune/val_data/electricity.yaml b/cli/conf/new_finetune/val_data/electricity.yaml index 61e8c7e..74981e0 100644 --- a/cli/conf/new_finetune/val_data/electricity.yaml +++ b/cli/conf/new_finetune/val_data/electricity.yaml @@ -1,13 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: electricity_eval - offset: 18412 # Same as _lsf_dataset.py - eval_length: 2630 # Same as _lsf_dataset.py, test_length=5260 - prediction_lengths: ??? - context_lengths: ??? - patch_sizes: ??? - -# prediction_lengths: [96, 192, 336, 720] -# context_lengths: [3000] -# patch_sizes: [32, 64] # freq='h' \ No newline at end of file +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: electricity_eval +offset: 18412 # Same as _lsf_dataset.py +eval_length: 2630 # Same as _lsf_dataset.py, test_length=5260 +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/val_data/etth1.yaml b/cli/conf/new_finetune/val_data/etth1.yaml index 752129a..a409cde 100644 --- a/cli/conf/new_finetune/val_data/etth1.yaml +++ b/cli/conf/new_finetune/val_data/etth1.yaml @@ -1,9 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: ETTh1_eval - offset: 8640 - eval_length: 2880 - prediction_lengths: [96, 192, 336, 720] - context_lengths: [1000, 2000, 3000, 4000, 5000] - patch_sizes: [32, 64] \ No newline at end of file +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTh1_eval +offset: 8640 +eval_length: 2880 +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/val_data/etth2.yaml b/cli/conf/new_finetune/val_data/etth2.yaml index 5fc653c..31ca968 100644 --- a/cli/conf/new_finetune/val_data/etth2.yaml +++ b/cli/conf/new_finetune/val_data/etth2.yaml @@ -1,13 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: ETTh2_eval - offset: 8640 # Same as _lsf_dataset.py - eval_length: 2880 # Same as _lsf_dataset.py - prediction_lengths: ??? - context_lengths: ??? - patch_sizes: ??? 
- -# prediction_lengths: [ 96, 192, 336, 720 ] -# context_lengths: [ 3000 ] -# patch_sizes: [ 32, 64 ] +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTh2_eval +offset: 8640 # Same as _lsf_dataset.py +eval_length: 2880 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/val_data/ettm1.yaml b/cli/conf/new_finetune/val_data/ettm1.yaml index e6a15b4..3f0244c 100644 --- a/cli/conf/new_finetune/val_data/ettm1.yaml +++ b/cli/conf/new_finetune/val_data/ettm1.yaml @@ -1,14 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: ETTm1_eval - offset: 34560 # Same as _lsf_dataset.py - eval_length: 11520 # Same as _lsf_dataset.py - prediction_lengths: ??? - context_lengths: ??? - patch_sizes: ??? - - -# prediction_lengths: [96, 192, 336, 720] -# context_lengths: [ 3000 ] -# patch_sizes: [ 32, 64, 128 ] # freq="15T" \ No newline at end of file +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTm1_eval +offset: 34560 # Same as _lsf_dataset.py +eval_length: 11520 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/val_data/ettm2.yaml b/cli/conf/new_finetune/val_data/ettm2.yaml index cb070fd..0939493 100644 --- a/cli/conf/new_finetune/val_data/ettm2.yaml +++ b/cli/conf/new_finetune/val_data/ettm2.yaml @@ -1,14 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: ETTm2_eval - offset: 34560 # Same as _lsf_dataset.py - eval_length: 11520 # Same as _lsf_dataset.py - prediction_lengths: ??? - context_lengths: ??? - patch_sizes: ??? - - -# prediction_lengths: [96, 192, 336, 720] -# context_lengths: [3000] -# patch_sizes: [32, 64, 128] # "freq=15T" +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: ETTm2_eval +offset: 34560 # Same as _lsf_dataset.py +eval_length: 11520 # Same as _lsf_dataset.py +prediction_length: ??? +context_length: ??? +patch_size: ??? \ No newline at end of file diff --git a/cli/conf/new_finetune/val_data/weather.yaml b/cli/conf/new_finetune/val_data/weather.yaml index 8f1973e..c2a23de 100644 --- a/cli/conf/new_finetune/val_data/weather.yaml +++ b/cli/conf/new_finetune/val_data/weather.yaml @@ -1,9 +1,7 @@ -_target_: uni2ts.data.builder.ConcatDatasetBuilder -_args_: - _target_: uni2ts.data.builder.simple.generate_eval_builders - dataset: weather_eval - offset: 36887 # Same as _lsf_dataset.py - eval_length: 5269 # Same as _lsf_dataset.py; test_length=10539 - prediction_lengths: ??? - context_lengths: ??? - patch_sizes: ??? \ No newline at end of file +_target_: uni2ts.data.builder.simple.generate_eval_builder +dataset: weather_eval +offset: 36887 # Same as _lsf_dataset.py +eval_length: 5269 # Same as _lsf_dataset.py; test_length=10539 +prediction_length: ??? +context_length: ??? +patch_size: ??? 
\ No newline at end of file diff --git a/cli/train.py b/cli/train.py index 99dc161..3d080d3 100644 --- a/cli/train.py +++ b/cli/train.py @@ -128,7 +128,7 @@ def main(cfg: DictConfig): model: L.LightningModule = instantiate(cfg.model, _convert_="all") - if 'collate_fn' not in cfg.train_dataloader: + if "collate_fn" not in cfg.train_dataloader: model.seq_fields = model.seq_fields + ("sample_id",) if cfg.compile: @@ -151,8 +151,26 @@ def main(cfg: DictConfig): ) L.seed_everything(cfg.seed + trainer.logger.version, workers=True) - print("Number of windows in train: ", train_dataset.dataset_weight * train_dataset.num_ts) + print( + "Number of windows in train: ", + train_dataset.dataset_weight * train_dataset.num_ts, + ) + print("Batch size for train: ", cfg.train_dataloader.batch_size) + print( + "Number of batches in a epoch: ", + train_dataset.dataset_weight + * train_dataset.num_ts + // cfg.train_dataloader.batch_size, + ) + print("Number of windows in val: ", val_dataset.dataset_weight * val_dataset.num_ts) + print("Batch size for val: ", cfg.val_dataloader.batch_size) + print( + "Number of batches in a epoch: ", + val_dataset.dataset_weight + * val_dataset.num_ts + // cfg.val_dataloader.batch_size, + ) # Validate before training, check the performance of original pretrained model. trainer.validate(model, datamodule=DataModule(cfg, train_dataset, val_dataset)) diff --git a/project/multi_scale_finetune/eval/small/run_multi.sh b/project/multi_scale_finetune/eval/small/run_multi.sh index f13c439..28b7892 100644 --- a/project/multi_scale_finetune/eval/small/run_multi.sh +++ b/project/multi_scale_finetune/eval/small/run_multi.sh @@ -1,8 +1,8 @@ #!/bin/bash -bash project/multi_scale_fintune/eval/small/etth1.sh -bash project/multi_scale_fintune/eval/small/etth2.sh -bash project/multi_scale_fintune/eval/small/ettm1.sh -bash project/multi_scale_fintune/eval/small/ettm2.sh -bash project/multi_scale_fintune/eval/small/weather.sh +bash project/multi_scale_finetune/eval/small/etth1.sh +bash project/multi_scale_finetune/eval/small/etth2.sh +bash project/multi_scale_finetune/eval/small/ettm1.sh +bash project/multi_scale_finetune/eval/small/ettm2.sh +bash project/multi_scale_finetune/eval/small/weather.sh diff --git a/project/new_finetune/eval/small/electricity.sh b/project/new_finetune/eval/small/electricity.sh new file mode 100644 index 0000000..476719b --- /dev/null +++ b/project/new_finetune/eval/small/electricity.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + + +cpp1='' +cpp2='' +cpp3='' +cpp4='' +index=1 +for pl in 96 192 336 720; do + case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=64 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=electricity \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done + diff --git a/project/new_finetune/eval/small/etth1.sh b/project/new_finetune/eval/small/etth1.sh new file mode 100644 index 0000000..5b1baff --- /dev/null +++ b/project/new_finetune/eval/small/etth1.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export 
CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + +cpp1='./outputs/multi_scale_finetune/moirai_1.1_R_small/dev/full/etth1/cl3000_pl96/checkpoints/epoch_3-step_40.ckpt' +cpp2='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/etth1/cl3000_pl192/checkpoints/epoch_1-step_20.ckpt' +cpp3='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/etth1/cl3000_pl336/checkpoints/epoch_1-step_20.ckpt' +cpp4='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/etth1/cl3000_pl720/checkpoints/epoch_1-step_20.ckpt' + +index=1 +for pl in 96 192 336 720; do + case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=64 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=ETTh1 \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done \ No newline at end of file diff --git a/project/new_finetune/eval/small/etth2.sh b/project/new_finetune/eval/small/etth2.sh new file mode 100644 index 0000000..c4ab237 --- /dev/null +++ b/project/new_finetune/eval/small/etth2.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + + +cpp1='./outputs/finetune/moirai_1.1_R_small/lsf/norm/etth2/cl3000_pl96/checkpoints/epoch_3-step_400.ckpt' +cpp2='./outputs/finetune/moirai_1.1_R_small/lsf/norm/etth2/cl3000_pl192/checkpoints/epoch_4-step_500.ckpt' +cpp3='./outputs/finetune/moirai_1.1_R_small/lsf/norm/etth2/cl3000_pl336/checkpoints/epoch_4-step_500.ckpt' +cpp4='./outputs/finetune/moirai_1.1_R_small/lsf/norm/etth2/cl3000_pl720/checkpoints/epoch_6-step_700.ckpt' + +index=1 +for pl in 96 192 336 720; do + case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=64 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=ETTh2 \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done diff --git a/project/new_finetune/eval/small/ettm1.sh b/project/new_finetune/eval/small/ettm1.sh new file mode 100644 index 0000000..77c67a2 --- /dev/null +++ b/project/new_finetune/eval/small/ettm1.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + + +cpp1='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl96/checkpoints/epoch_6-step_70.ckpt' +cpp2='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl192/checkpoints/epoch_2-step_30.ckpt' +cpp3='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl336/checkpoints/epoch_1-step_20.ckpt' +cpp4='./outputs/multi_scale_finetune/moirai_1.1_R_small/lsf/full/ettm1/cl3000_pl720/checkpoints/epoch_1-step_20.ckpt' + +index=1 +for pl in 96 192 336 720; do + 
case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=128 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=ETTm1 \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done diff --git a/project/new_finetune/eval/small/ettm2.sh b/project/new_finetune/eval/small/ettm2.sh new file mode 100644 index 0000000..3ad4c7c --- /dev/null +++ b/project/new_finetune/eval/small/ettm2.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + + +cpp1='./outputs/finetune/moirai_1.1_R_small/lsf/norm/ettm2/cl3000_pl96/checkpoints/epoch_0-step_100.ckpt' +cpp2='./outputs/finetune/moirai_1.1_R_small/lsf/norm/ettm2/cl3000_pl192/checkpoints/epoch_61-step_6200.ckpt' +cpp3='./outputs/finetune/moirai_1.1_R_small/lsf/norm/ettm2/cl3000_pl336/checkpoints/epoch_10-step_1100.ckpt' +cpp4='./outputs/finetune/moirai_1.1_R_small/lsf/norm/ettm2/cl3000_pl720/checkpoints/epoch_11-step_1200.ckpt' + + +index=1 +for pl in 96 192 336 720; do + case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=64 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=ETTm2 \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done diff --git a/project/new_finetune/eval/small/run_multi.sh b/project/new_finetune/eval/small/run_multi.sh new file mode 100644 index 0000000..42d1af1 --- /dev/null +++ b/project/new_finetune/eval/small/run_multi.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +bash project/new_finetune/eval/small/etth1.sh +bash project/new_finetune/eval/small/etth2.sh +bash project/new_finetune/eval/small/ettm1.sh +bash project/new_finetune/eval/small/ettm2.sh +bash project/new_finetune/eval/small/weather.sh + diff --git a/project/new_finetune/eval/small/weather.sh b/project/new_finetune/eval/small/weather.sh new file mode 100644 index 0000000..5a4276f --- /dev/null +++ b/project/new_finetune/eval/small/weather.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1 +export CUDA_VISIBLE_DEVICES=0 + +mode=S +cp=conf/new_eval +exp_name=lsf_finetune +cl=3000 +model=moirai_lightning_ckpt + + +cpp1='./outputs/finetune/moirai_1.1_R_small/lsf/param_proj/weather/cl3000_pl96/checkpoints/epoch_0-step_300.ckpt' +cpp2='./outputs/finetune/moirai_1.1_R_small/lsf/param_proj/weather/cl3000_pl192/checkpoints/epoch_120-step_36300.ckpt' +cpp3='./outputs/finetune/moirai_1.1_R_small/lsf/param_proj/weather/cl3000_pl336/checkpoints/epoch_154-step_46500.ckpt' +cpp4='./outputs/finetune/moirai_1.1_R_small/lsf/param_proj/weather/cl3000_pl720/checkpoints/epoch_91-step_27600.ckpt' + +index=1 +for pl in 96 192 336 720; do + case $index in + 1) cpp=$cpp1 ;; + 2) cpp=$cpp2 ;; + 3) cpp=$cpp3 ;; + 4) cpp=$cpp4 ;; + esac + + pretrained_model=$(echo $cpp | cut -d'/' -f4) + 
ft_pattern=$(echo $cpp | cut -d'/' -f6) + + python -m cli.eval \ + -cp $cp \ + exp_name=$exp_name/$pretrained_model/$ft_pattern \ + model=$model \ + model.patch_size=128 \ + model.context_length=$cl \ + model.checkpoint_path=$cpp \ + model.pretrained_checkpoint_path=ckpt/$pretrained_model.ckpt \ + data=lsf_test \ + data.dataset_name=weather \ + data.mode=$mode \ + data.prediction_length=$pl + + index=$((index+1)) +done diff --git a/project/new_finetune/train/small/electricity.sh b/project/new_finetune/train/small/electricity.sh new file mode 100644 index 0000000..19e7478 --- /dev/null +++ b/project/new_finetune/train/small/electricity.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small +cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=electricity +ps=64 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/project/new_finetune/train/small/etth1.sh b/project/new_finetune/train/small/etth1.sh new file mode 100644 index 0000000..60b216a --- /dev/null +++ b/project/new_finetune/train/small/etth1.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small +cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=etth1 +ps=64 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/project/new_finetune/train/small/etth2.sh b/project/new_finetune/train/small/etth2.sh new file mode 100644 index 0000000..a09c28b --- /dev/null +++ b/project/new_finetune/train/small/etth2.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small +cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=etth2 +ps=64 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/project/new_finetune/train/small/ettm1.sh b/project/new_finetune/train/small/ettm1.sh new file mode 100644 index 0000000..e6798cb --- /dev/null +++ b/project/new_finetune/train/small/ettm1.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small 
+cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=ettm1 +ps=128 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/project/new_finetune/train/small/ettm2.sh b/project/new_finetune/train/small/ettm2.sh new file mode 100644 index 0000000..705dd65 --- /dev/null +++ b/project/new_finetune/train/small/ettm2.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small +cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=ettm2 +ps=64 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/project/new_finetune/train/small/run_multi.sh b/project/new_finetune/train/small/run_multi.sh new file mode 100644 index 0000000..fe18de5 --- /dev/null +++ b/project/new_finetune/train/small/run_multi.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +bash project/new_finetune/train/small/etth1.sh +bash project/new_finetune/train/small/etth2.sh +bash project/new_finetune/train/small/ettm1.sh +bash project/new_finetune/train/small/ettm2.sh +bash project/new_finetune/train/small/weather.sh \ No newline at end of file diff --git a/project/new_finetune/train/small/weather.sh b/project/new_finetune/train/small/weather.sh new file mode 100644 index 0000000..70dadcf --- /dev/null +++ b/project/new_finetune/train/small/weather.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +export HYDRA_FULL_ERROR=1; export CUDA_VISIBLE_DEVICES=0; + +model=moirai_1.1_R_small +cp=conf/new_finetune +exp_name=lsf +cl=3000 +ft_pattern=full + +data=weather +ps=128 + +for pl in 96 192 336 720; do + python -m cli.train \ + -cp $cp \ + exp_name=$exp_name \ + run_name=cl${cl}_pl${pl} \ + model=$model \ + model.patch_size=${ps} \ + model.context_length=$cl \ + model.prediction_length=$pl \ + model.finetune_pattern=$ft_pattern \ + data=${data} \ + data.patch_size=${ps} \ + data.context_length=$cl \ + data.prediction_length=$pl \ + val_data=${data} \ + val_data.patch_sizes=${ps} \ + val_data.context_lengths=$cl \ + val_data.prediction_lengths=$pl +done \ No newline at end of file diff --git a/src/uni2ts/data/builder/simple.py b/src/uni2ts/data/builder/simple.py index 9092214..b0c7a6b 100644 --- a/src/uni2ts/data/builder/simple.py +++ b/src/uni2ts/data/builder/simple.py @@ -26,7 +26,12 @@ from uni2ts.common.env import env from uni2ts.common.typing import GenFunc -from uni2ts.data.dataset import EvalDataset, SampleTimeSeriesType, TimeSeriesDataset, FinetuneDataset +from uni2ts.data.dataset import ( + EvalDataset, + FinetuneDataset, + SampleTimeSeriesType, + TimeSeriesDataset, +) from uni2ts.data.indexer import HuggingFaceDatasetIndexer from 
uni2ts.transform import Transformation @@ -334,25 +339,6 @@ def scale(self, data, start, end): return (data - self.mean) / self.std -def generate_finetune_builder( - dataset: str, - train_length: int, - prediction_length: int, - context_length: int, - patch_size: int, - storage_path: Path = env.CUSTOM_DATA_PATH, -) -> SimpleFinetuneDatasetBuilder: - return SimpleFinetuneDatasetBuilder( - dataset=dataset, - windows=train_length - context_length - prediction_length + 1, - distance=1, - prediction_length=prediction_length, - context_length=context_length, - patch_size=patch_size, - storage_path=storage_path, - ) - - @dataclass class SimpleEvalDatasetBuilder(DatasetBuilder): dataset: str @@ -419,6 +405,81 @@ def load_dataset( ) +def generate_finetune_builder( + dataset: str, + train_length: int, + prediction_length: int, + context_length: int, + patch_size: int, + storage_path: Path = env.CUSTOM_DATA_PATH, +) -> SimpleFinetuneDatasetBuilder: + """ + Set distance=1 for training data. Same as standard LSF setting. + """ + return SimpleFinetuneDatasetBuilder( + dataset=dataset, + windows=train_length - context_length - prediction_length + 1, + distance=1, + prediction_length=prediction_length, + context_length=context_length, + patch_size=patch_size, + storage_path=storage_path, + ) + + +def generate_eval_builder( + dataset: str, + offset: int, + eval_length: int, + prediction_length: int, + context_length: int, + patch_size: int, + storage_path: Path = env.CUSTOM_DATA_PATH, +) -> SimpleEvalDatasetBuilder: + """ + Set distance according to dataset. Decrease the number of validation samples to reduce computational cost. + """ + distances = { + "ETTh1_eval": 13, # 13h + "ETTh2_eval": 13, + "ETTm1_eval": 25, # 6h 15min + "ETTm2_eval": 25, + "weather_eval": 37, # 6h 10 min + "electricity_eval": 49, # 2d 1h + } + if dataset in distances: + distance = distances[dataset] + windows = (eval_length - prediction_length) // distance + 1 + else: + distance = prediction_length + windows = eval_length // prediction_length + + # base = 8 # base can change for different datasets + # overlap_ratio = { + # 96: base, + # 192: 2*base, + # 336: 4*base, + # 720: 8*base, + # } + # if prediction_length in overlap_ratio: + # distance = prediction_length // overlap_ratio[prediction_length] + # windows = (eval_length - prediction_length) // distance + 1 + # else: + # distance = prediction_length + # windows = eval_length // prediction_length + + return SimpleEvalDatasetBuilder( + dataset=dataset, + offset=offset, + windows=windows, + distance=distance, + prediction_length=prediction_length, + context_length=context_length, + patch_size=patch_size, + storage_path=storage_path, + ) + + def generate_eval_builders( dataset: str, offset: int, @@ -434,8 +495,6 @@ def generate_eval_builders( offset=offset, windows=eval_length // pred, distance=pred, - # windows=eval_length - pred + 1, - # distance=1, prediction_length=pred, context_length=ctx, patch_size=psz, diff --git a/src/uni2ts/data/dataset.py b/src/uni2ts/data/dataset.py index 29a5aac..5e01e7d 100644 --- a/src/uni2ts/data/dataset.py +++ b/src/uni2ts/data/dataset.py @@ -212,9 +212,7 @@ def _get_data(self, idx: int) -> dict[str, Data]: class FinetuneDataset(TimeSeriesDataset): - """ - - """ + """ """ def __init__( self, diff --git a/src/uni2ts/model/new_moirai/finetune.py b/src/uni2ts/model/new_moirai/finetune.py index 1b3a3a3..4614b93 100644 --- a/src/uni2ts/model/new_moirai/finetune.py +++ b/src/uni2ts/model/new_moirai/finetune.py @@ -42,6 +42,7 @@ from uni2ts.optim import 
SchedulerType, get_scheduler from uni2ts.transform import ( AddObservedMask, + AddSampleIndex, AddTimeIndex, AddVariateIndex, DefaultPatchSizeConstraints, @@ -50,6 +51,7 @@ EvalMaskedPrediction, EvalPad, ExtendMask, + FinetunePatchCrop, FixedPatchSizeConstraints, FlatPackCollection, FlatPackFields, @@ -66,8 +68,6 @@ SelectFields, SequencifyField, Transformation, - FinetunePatchCrop, - AddSampleIndex ) from .module import MoiraiModule @@ -292,7 +292,7 @@ def configure_optimizers(self) -> dict: if "in_proj" in pn: p.requires_grad = True - if "norm" in self.finetune_pattern: # + if "norm" in self.finetune_pattern: for pn, p in self.named_parameters(): if "norm1" in pn or "norm2" in pn: p.requires_grad = True @@ -409,17 +409,16 @@ def configure_optimizers(self) -> dict: eps=1e-6, ) scheduler = get_scheduler( - SchedulerType.REDUCE_ON_PLATEAU, + SchedulerType.CONSTANT, # Use constant lr scheduler optimizer, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=self.hparams.num_training_steps, - # scheduler_specific_kwargs={'monitor': "val/{self.hparams.loss_func.__class__.__name__}"} ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, - "monitor": "lr-AdamW/pg1", + "monitor": "train_loss", "interval": "step", }, } @@ -503,9 +502,9 @@ def default_train_transform( + AddSampleIndex( fields=("target",), optional_fields=("past_feat_dynamic_real",), - sample_id_field = "sample_id", - expected_ndim = 3, - collection_type = dict, + sample_id_field="sample_id", + expected_ndim=3, + collection_type=dict, ) + EvalMaskedPrediction( mask_length=math.ceil(prediction_length / patch_size), @@ -530,8 +529,8 @@ def default_train_transform( feat=False, ) + FlatPackCollection( - field="sample_id", - feat=False, + field="sample_id", + feat=False, ) + FlatPackCollection( field="prediction_mask", diff --git a/src/uni2ts/module/attention.py b/src/uni2ts/module/attention.py index fb2245e..1b70302 100644 --- a/src/uni2ts/module/attention.py +++ b/src/uni2ts/module/attention.py @@ -109,27 +109,26 @@ def __init__( self.attn_dropout_p = attn_dropout_p self.out_proj = nn.Linear(dim, dim, bias=bias) - self.query_filmed_generator = nn.ModuleList( - [ - nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length - nn.Linear(in_features=dim, out_features=2 * 6) - ] - ) - - self.key_filmed_generator = nn.ModuleList( - [ - nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length - nn.Linear(in_features=dim, out_features=2 * 6) - ] - ) - - # self.value_filmed_generator = nn.ModuleList( + # self.query_filmed_generator = nn.ModuleList( # [ # nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length # nn.Linear(in_features=dim, out_features=2 * 6) # ] # ) - + # + # self.key_filmed_generator = nn.ModuleList( + # [ + # nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length + # nn.Linear(in_features=dim, out_features=2 * 6) + # ] + # ) + # + # # self.value_filmed_generator = nn.ModuleList( + # # [ + # # nn.Linear(in_features=dim, out_features=2 * 12), # each scale's length + # # nn.Linear(in_features=dim, out_features=2 * 6) + # # ] + # # ) def _get_var_id( self, @@ -203,7 +202,7 @@ def _update_attn_mask( ) -> Optional[ Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] - ]: + ]: if attn_mask is not None: attn_mask = rearrange( attn_mask, @@ -214,11 +213,14 @@ def _update_attn_mask( # Bias scalars are different in different groups. 
if self.var_attn_bias is not None: - attn_bias = attn_bias + self.var_attn_bias( # 2 scales for same-variate and different-variate positions - query, - key, - query_id=query_var_id, - kv_id=kv_var_id, + attn_bias = ( + attn_bias + + self.var_attn_bias( # 2 scales for same-variate and different-variate positions + query, + key, + query_id=query_var_id, + kv_id=kv_var_id, + ) ) if self.time_attn_bias is not None: @@ -269,7 +271,7 @@ def get_token_index_of_target_variate_per_sample( self, variate_id: Int[torch.Tensor, "*batch q_len"], attn_mask: Bool[torch.Tensor, "*batch q_len kv_len"], - target_variate: int = 1 # Default to variate_id = 1 + target_variate: int = 1, # Default to variate_id = 1 ): # ToDo: 当前假设batch中所有的variate_id和attn_mask是一样的 @@ -284,15 +286,21 @@ def get_token_index_of_target_variate_per_sample( # Step 2: 使用无序组合,确保 (variate_1, variate_2) 和 (variate_2, variate_1) 的映射相同 # 扩展维度以广播到 q_len x q_len 矩阵 - variate_id_min = torch.minimum(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)) - variate_id_max = torch.maximum(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2)) + variate_id_min = torch.minimum( + variate_id.unsqueeze(-1), variate_id.unsqueeze(-2) + ) + variate_id_max = torch.maximum( + variate_id.unsqueeze(-1), variate_id.unsqueeze(-2) + ) # 使用偏移量生成唯一组合值,确保无序组合的对称性 variate_pair_matrix = variate_id_min * (max_variate_id + 1) + variate_id_max # Step 3: 找到 variate_id = target_variate 的组合 - target_combination_value = target_variate * (max_variate_id + 1) + target_variate - variate_target_mask = (variate_pair_matrix == target_combination_value) + target_combination_value = ( + target_variate * (max_variate_id + 1) + target_variate + ) + variate_target_mask = variate_pair_matrix == target_combination_value # Step 4: 用 attn_mask 进行 AND 运算,筛选出每个 sample 内 variate_id = target_variate 的组合 final_mask = variate_target_mask & attn_mask @@ -311,15 +319,15 @@ def get_token_index_of_target_variate_per_sample( return index_per_sample def forward( - self, - query: Float[torch.Tensor, "*batch q_len dim"], - key: Float[torch.Tensor, "*batch kv_len dim"], - value: Float[torch.Tensor, "*batch kv_len dim"], - attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, - query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, - kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, - query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, - kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + self, + query: Float[torch.Tensor, "*batch q_len dim"], + key: Float[torch.Tensor, "*batch kv_len dim"], + value: Float[torch.Tensor, "*batch kv_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, ) -> Float[torch.Tensor, "*batch q_len dim"]: query = self.q_proj(query) key = self.k_proj(key) @@ -353,7 +361,6 @@ def forward( # # -1) / 2):] # # value[..., index, :] = value_weight.unsqueeze(-1) * value_i + value_bias.unsqueeze(-1) - query = self.q_norm( rearrange( query, @@ -478,18 +485,18 @@ def __init__( class FilmedGroupedQueryAttention(nn.Module): def __init__( - self, - dim: int, - num_heads: int, - num_groups: int, - bias: bool = True, - norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, - softmax_scale: Optional[float] = None, - 
attn_dropout_p: float = 0.0, - var_attn_bias: Optional[Callable[[], AttentionBias]] = None, - time_attn_bias: Optional[Callable[[], AttentionBias]] = None, - var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, - time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + self, + dim: int, + num_heads: int, + num_groups: int, + bias: bool = True, + norm_layer: Optional[type[nn.Module] | partial[nn.Module]] = nn.LayerNorm, + softmax_scale: Optional[float] = None, + attn_dropout_p: float = 0.0, + var_attn_bias: Optional[Callable[[], AttentionBias]] = None, + time_attn_bias: Optional[Callable[[], AttentionBias]] = None, + var_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, + time_qk_proj: Optional[Callable[[], QueryKeyProjection]] = None, ): super().__init__() assert num_heads > 0 and dim % num_heads == 0 @@ -526,11 +533,11 @@ def __init__( # ) def _get_var_id( - self, - query: Float[torch.Tensor, "*batch group hpg q_len dim"], - key: Float[torch.Tensor, "*batch group hpg kv_len dim"], - query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]], - kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]], ) -> tuple[ Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], @@ -555,11 +562,11 @@ def _get_var_id( return query_var_id, kv_var_id def _get_time_id( - self, - query: Float[torch.Tensor, "*batch group hpg q_len dim"], - key: Float[torch.Tensor, "*batch group hpg kv_len dim"], - query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]], - kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]], + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]], ) -> tuple[ Optional[Int[torch.Tensor, "*batch 1 1 q_len"]], Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]], @@ -586,18 +593,18 @@ def _get_time_id( return query_time_id, kv_time_id def _update_attn_mask( - self, - attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]], - query: Float[torch.Tensor, "*batch group hpg q_len dim"], - key: Float[torch.Tensor, "*batch group hpg kv_len dim"], - query_var_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, - kv_var_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, - query_time_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, - kv_time_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + self, + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]], + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch 1 1 q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch 1 1 kv_len"]] = None, ) -> Optional[ Bool[torch.Tensor, "*batch #group #hpg q_len kv_len"] | Float[torch.Tensor, "*batch #group #hpg q_len kv_len"] - ]: + ]: if attn_mask is not None: attn_mask = rearrange( attn_mask, @@ -608,11 +615,14 @@ def _update_attn_mask( # Bias scalars are different in 
different groups. if self.var_attn_bias is not None: - attn_bias = attn_bias + self.var_attn_bias( # 2 scales for same-variate and different-variate positions - query, - key, - query_id=query_var_id, - kv_id=kv_var_id, + attn_bias = ( + attn_bias + + self.var_attn_bias( # 2 scales for same-variate and different-variate positions + query, + key, + query_id=query_var_id, + kv_id=kv_var_id, + ) ) if self.time_attn_bias is not None: @@ -630,19 +640,19 @@ def _update_attn_mask( attn_bias if attn_mask is None else attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) - # Mask out positions from differnet samples + # Mask out positions from differnet samples ) ) return attn_mask def _qk_proj( - self, - query: Float[torch.Tensor, "*batch group hpg q_len dim"], - key: Float[torch.Tensor, "*batch group hpg kv_len dim"], - query_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], - kv_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], - query_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], - kv_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + self, + query: Float[torch.Tensor, "*batch group hpg q_len dim"], + key: Float[torch.Tensor, "*batch group hpg kv_len dim"], + query_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_var_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], + query_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg q_len"]], + kv_time_id: Optional[Int[torch.Tensor, "*batch #group #hpg kv_len"]], ) -> tuple[ Float[torch.Tensor, "*batch group hpg q_len dim"], Float[torch.Tensor, "*batch group hpg kv_len dim"], @@ -660,15 +670,15 @@ def _qk_proj( return query, key def forward( - self, - query: Float[torch.Tensor, "*batch q_len dim"], - key: Float[torch.Tensor, "*batch kv_len dim"], - value: Float[torch.Tensor, "*batch kv_len dim"], - attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, - query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, - kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, - query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, - kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + self, + query: Float[torch.Tensor, "*batch q_len dim"], + key: Float[torch.Tensor, "*batch kv_len dim"], + value: Float[torch.Tensor, "*batch kv_len dim"], + attn_mask: Optional[Bool[torch.Tensor, "*batch q_len kv_len"]] = None, + query_var_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_var_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, + query_time_id: Optional[Int[torch.Tensor, "*batch q_len"]] = None, + kv_time_id: Optional[Int[torch.Tensor, "*batch kv_len"]] = None, ) -> Float[torch.Tensor, "*batch q_len dim"]: query = self.q_proj(query) key = self.k_proj(key) @@ -676,7 +686,6 @@ def forward( # ToDo: Plan B: Directly apply different Film on query / key to different scales. W.o revising RoPE - query = self.q_norm( rearrange( query, @@ -740,4 +749,4 @@ def forward( scale=self.softmax_scale, ) out = rearrange(out, "... group hpg q_len dim -> ... 
q_len (group hpg dim)") - return self.out_proj(out) \ No newline at end of file + return self.out_proj(out) diff --git a/src/uni2ts/module/position/attn_projection.py b/src/uni2ts/module/position/attn_projection.py index 3b86a91..e5f6840 100644 --- a/src/uni2ts/module/position/attn_projection.py +++ b/src/uni2ts/module/position/attn_projection.py @@ -145,7 +145,9 @@ def __init__( kwargs: Optional[dict[str, Any]] = None, key_proj_layer: Optional[type[Projection]] = None, key_kwargs: Optional[dict[str, Any]] = None, - partial_factor: Optional[tuple[float, float]] = None, # QZ: Only rotate part of embedding dimension + partial_factor: Optional[ + tuple[float, float] + ] = None, # QZ: Only rotate part of embedding dimension ): super().__init__() if partial_factor is not None: diff --git a/src/uni2ts/module/transformer.py b/src/uni2ts/module/transformer.py index 3558d1b..84c2b61 100644 --- a/src/uni2ts/module/transformer.py +++ b/src/uni2ts/module/transformer.py @@ -141,7 +141,7 @@ def __init__( ) get_self_attn = partial( - GroupedQueryAttention, # ToDo: If I change it, can I load MoiraiModule with original ckpt? + GroupedQueryAttention, # ToDo: If I change it, can I load MoiraiModule with original ckpt? dim=d_model, num_heads=num_heads, num_groups=num_groups, diff --git a/src/uni2ts/transform/__init__.py b/src/uni2ts/transform/__init__.py index e80eca6..759804e 100644 --- a/src/uni2ts/transform/__init__.py +++ b/src/uni2ts/transform/__init__.py @@ -14,8 +14,8 @@ # limitations under the License. from ._base import Chain, Identity, Transformation -from .crop import EvalCrop, PatchCrop, PatchCropGivenFixedConfig, FinetunePatchCrop -from .feature import AddObservedMask, AddTimeIndex, AddVariateIndex, AddSampleIndex +from .crop import EvalCrop, FinetunePatchCrop, PatchCrop, PatchCropGivenFixedConfig +from .feature import AddObservedMask, AddSampleIndex, AddTimeIndex, AddVariateIndex from .field import LambdaSetFieldIfNotPresent, RemoveFields, SelectFields, SetValue from .imputation import DummyValueImputation, ImputeTimeSeries, LastValueImputation from .multi_scale import ( @@ -102,5 +102,5 @@ "GetSeasonalNaivePrediction", "AddSeasonalNaiveTarget", "SeasonalNaiveEvalCrop", - "FinetunePatchCrop" + "FinetunePatchCrop", ] diff --git a/src/uni2ts/transform/crop.py b/src/uni2ts/transform/crop.py index b0a3a99..e671828 100644 --- a/src/uni2ts/transform/crop.py +++ b/src/uni2ts/transform/crop.py @@ -224,11 +224,9 @@ def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: return a, b - @dataclass class FinetunePatchCrop(MapFuncMixin, Transformation): - """ - """ + """ """ distance: int prediction_length: int @@ -261,4 +259,3 @@ def _get_boundaries(self, data_entry: dict[str, Any]) -> tuple[int, int]: assert time >= b > a >= 0 return a, b -
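For readers following the attention.py hunks above, get_token_index_of_target_variate_per_sample builds a symmetric variate-pair code so that (variate_1, variate_2) and (variate_2, variate_1) map to the same value, then intersects it with the attention mask. The standalone sketch below reproduces only that encoding step with made-up tensor values purely for illustration; it is not part of the patch.

import torch

# Toy inputs: 5 tokens with variate ids 0,0,1,1,2; attention allowed everywhere.
variate_id = torch.tensor([0, 0, 1, 1, 2])
attn_mask = torch.ones(5, 5, dtype=torch.bool)
max_variate_id = int(variate_id.max())            # largest variate id present

# Unordered pair encoding: min * (K + 1) + max is symmetric in the two ids.
lo = torch.minimum(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2))
hi = torch.maximum(variate_id.unsqueeze(-1), variate_id.unsqueeze(-2))
pair = lo * (max_variate_id + 1) + hi

# Keep query/key pairs where both tokens belong to the target variate (1 by default),
# then AND with the attention mask to stay within each sample.
target_variate = 1
target_code = target_variate * (max_variate_id + 1) + target_variate
final_mask = (pair == target_code) & attn_mask
print(final_mask.nonzero())                       # positions (2,2), (2,3), (3,2), (3,3)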