Commit cf9be17
Merge pull request #129 from a-r-r-o-w/condition-precomputation
Precomputation of conditions and latents
sayakpaul authored Dec 23, 2024
2 parents 223add1 + 2858346 commit cf9be17
Showing 10 changed files with 697 additions and 124 deletions.
119 changes: 116 additions & 3 deletions README.md
@@ -143,6 +143,60 @@ video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```
### Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:

```
Training configuration: {
"trainable parameters": 117440512,
"total samples": 69,
"train epochs": 1,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 69,
"train batch size": 1,
"gradient accumulation steps": 1
}
```
| stage                   | memory_allocated (GB) | max_memory_reserved (GB) |
|:-----------------------:|:---------------------:|:------------------------:|
| before training start   | 13.486                | 13.879                   |
| before validation start | 14.146                | 17.623                   |
| after validation end    | 14.146                | 17.623                   |
| after epoch 1           | 14.146                | 17.623                   |
| after training end      | 4.461                 | 17.623                   |

Note: requires about `18` GB of VRAM without precomputation.
LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:

```
Training configuration: {
"trainable parameters": 117440512,
"total samples": 1,
"train epochs": 10,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 1,
"train batch size": 1,
"gradient accumulation steps": 1
}
```
| stage                         | memory_allocated (GB) | max_memory_reserved (GB) |
|:-----------------------------:|:---------------------:|:------------------------:|
| after precomputing conditions | 8.880                 | 8.920                    |
| after precomputing latents    | 9.684                 | 11.613                   |
| before training start         | 3.809                 | 10.010                   |
| after epoch 1                 | 4.260                 | 10.916                   |
| before validation start       | 4.260                 | 10.916                   |
| after validation end          | 13.924                | 17.262                   |
| after training end            | 4.260                 | 14.314                   |

Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to about `11` GB.
</details>
<details>
@@ -169,8 +223,7 @@ OUTPUT_DIR="/path/to/models/hunyuan-video/hunyuan-video-loras/hunyuan-video_caki
# Model arguments
model_cmd="--model_name hunyuan_video \
-  --pretrained_model_name_or_path tencent/HunyuanVideo
-  --revision refs/pr/18"
+  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"
# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
@@ -252,7 +305,7 @@ import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
@@ -272,10 +325,70 @@ output = pipe(
export_to_video(output, "output.mp4", fps=15)
```

### Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**:

```
Training configuration: {
"trainable parameters": 163577856,
"total samples": 69,
"train epochs": 1,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 69,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                   | memory_allocated (GB) | max_memory_reserved (GB) |
|:-----------------------:|:---------------------:|:------------------------:|
| before training start   | 38.889                | 39.020                   |
| before validation start | 39.747                | 56.266                   |
| after validation end    | 39.748                | 58.385                   |
| after epoch 1           | 39.748                | 40.910                   |
| after training end      | 25.288                | 40.910                   |

Note: requires about `59` GB of VRAM without precomputation.

LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**:

```
Training configuration: {
"trainable parameters": 163577856,
"total samples": 1,
"train epochs": 10,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 1,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                         | memory_allocated (GB) | max_memory_reserved (GB) |
|:-----------------------------:|:---------------------:|:------------------------:|
| after precomputing conditions | 14.232                | 14.461                   |
| after precomputing latents    | 14.717                | 17.244                   |
| before training start         | 24.195                | 26.039                   |
| after epoch 1                 | 24.830                | 42.387                   |
| before validation start       | 24.842                | 42.387                   |
| after validation end          | 39.558                | 46.947                   |
| after training end            | 24.842                | 41.039                   |

Note: requires about `47` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to about `42` GB.

</details>

If you would like to use a custom dataset, refer to the dataset preparation guide [here](./assets/dataset.md).

> [!NOTE]
> To lower memory requirements:
> - Pass `--precompute_conditions` when launching training (see the launch sketch below).
> - Pass `--gradient_checkpointing` when launching training.
> - Skip validation/testing; the significant amount of memory this saves can instead go toward training on smaller-VRAM GPUs.
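
For example, a minimal sketch of how the first two flags slot into a launch command, following the `*_cmd` pattern used earlier in this README (the `memory_cmd` variable name is illustrative):

```
# Memory-saving arguments (flag names from finetrainers/args.py)
memory_cmd="--precompute_conditions \
  --gradient_checkpointing"
```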
## Memory requirements

<table align="center">
39 changes: 39 additions & 0 deletions finetrainers/args.py
@@ -1,6 +1,8 @@
import argparse
from typing import Any, Dict, List, Optional, Tuple

import torch

from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS


@@ -20,6 +22,11 @@ class Args:
revision: Optional[str] = None
variant: Optional[str] = None
cache_dir: Optional[str] = None
text_encoder_dtype: torch.dtype = torch.bfloat16
text_encoder_2_dtype: torch.dtype = torch.bfloat16
text_encoder_3_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16

# Dataset arguments
data_root: str = None
@@ -32,6 +39,7 @@ class Args:
video_reshape_mode: Optional[str] = None
caption_dropout_p: float = 0.00
caption_dropout_technique: str = "empty"
precompute_conditions: bool = False

# Dataloader arguments
dataloader_num_workers: int = 0
@@ -113,6 +121,11 @@ def to_dict(self) -> Dict[str, Any]:
"revision": self.revision,
"variant": self.variant,
"cache_dir": self.cache_dir,
"text_encoder_dtype": self.text_encoder_dtype,
"text_encoder_2_dtype": self.text_encoder_2_dtype,
"text_encoder_3_dtype": self.text_encoder_3_dtype,
"transformer_dtype": self.transformer_dtype,
"vae_dtype": self.vae_dtype,
},
"dataset_arguments": {
"data_root": self.data_root,
@@ -124,6 +137,8 @@ def to_dict(self) -> Dict[str, Any]:
"video_resolution_buckets": self.video_resolution_buckets,
"video_reshape_mode": self.video_reshape_mode,
"caption_dropout_p": self.caption_dropout_p,
"caption_dropout_technique": self.caption_dropout_technique,
"precompute_conditions": self.precompute_conditions,
},
"dataloader_arguments": {
"dataloader_num_workers": self.dataloader_num_workers,
@@ -234,6 +249,11 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")


def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -317,6 +337,11 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
choices=["empty", "zero"],
help="Technique to use for caption dropout.",
)
parser.add_argument(
"--precompute_conditions",
action="store_true",
help="Whether or not to precompute the conditionings for the model.",
)
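
For illustration, a hedged sketch of what the `"empty"` caption-dropout technique above implies (assumed semantics inferred from the flag name and choices; `maybe_drop_caption` is hypothetical, not code from this PR, and `"zero"` presumably zeroes the text embedding downstream instead):

```
import random


def maybe_drop_caption(caption: str, dropout_p: float, technique: str) -> str:
    # With technique="empty", a dropped sample trains against the empty
    # prompt, in the style of classifier-free-guidance conditioning dropout.
    if technique == "empty" and random.random() < dropout_p:
        return ""
    return caption
```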


def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
@@ -645,6 +670,13 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
)


_DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}


def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args = Args()

@@ -654,6 +686,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.revision = args.revision
result_args.variant = args.variant
result_args.cache_dir = args.cache_dir
result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]

# Dataset arguments
if args.data_root is None and args.dataset_file is None:
@@ -668,6 +705,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
result_args.video_reshape_mode = args.video_reshape_mode
result_args.caption_dropout_p = args.caption_dropout_p
result_args.caption_dropout_technique = args.caption_dropout_technique
result_args.precompute_conditions = args.precompute_conditions

# Dataloader arguments
result_args.dataloader_num_workers = args.dataloader_num_workers
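For illustration, a minimal standalone sketch of the dtype plumbing added above (the parser fragment below is hypothetical; only `_DTYPE_MAP` and the `--*_dtype` flag names come from this diff):

```
import argparse

import torch

_DTYPE_MAP = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}

parser = argparse.ArgumentParser()
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=list(_DTYPE_MAP))
args = parser.parse_args(["--vae_dtype", "fp16"])

# The short string from the CLI resolves to a real torch.dtype for model loading.
vae_dtype = _DTYPE_MAP[args.vae_dtype]
assert vae_dtype == torch.float16
```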
4 changes: 4 additions & 0 deletions finetrainers/constants.py
@@ -19,6 +19,10 @@

FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")

PRECOMPUTED_DIR_NAME = "precomputed"
PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
PRECOMPUTED_LATENTS_DIR_NAME = "latents"

MODEL_DESCRIPTION = r"""
\# {model_id} {training_type} finetune
30 changes: 30 additions & 0 deletions finetrainers/dataset.py
@@ -1,3 +1,4 @@
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -19,6 +20,9 @@

decord.bridge.set_bridge("torch")

from .constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME


logger = get_logger(__name__)


@@ -257,6 +261,32 @@ def _find_nearest_resolution(self, height, width):
return nearest_res[1], nearest_res[2]


class PrecomputedDataset(Dataset):
def __init__(self, data_root: str) -> None:
super().__init__()

self.data_root = Path(data_root)

self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME

self.latent_conditions = sorted(os.listdir(self.latents_path))
self.text_conditions = sorted(os.listdir(self.conditions_path))

assert len(self.latent_conditions) == len(self.text_conditions), "Number of precomputed latent files and condition files does not match"

def __len__(self) -> int:
return len(self.latent_conditions)

def __getitem__(self, index: int) -> Dict[str, Any]:
conditions = {}
latent_path = self.latents_path / self.latent_conditions[index]
condition_path = self.conditions_path / self.text_conditions[index]
conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
return conditions


class BucketSampler(Sampler):
r"""
PyTorch Sampler that groups 3D data by height, width and frames.
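For context, a minimal sketch of the on-disk layout `PrecomputedDataset` reads from, with a hypothetical helper that populates it (the `save_precomputed_sample` name and zero-padded file naming are illustrative; the directory names match the constants added in finetrainers/constants.py):

```
from pathlib import Path

import torch

PRECOMPUTED_DIR_NAME = "precomputed"
PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
PRECOMPUTED_LATENTS_DIR_NAME = "latents"


def save_precomputed_sample(data_root: str, index: int, latent_conditions: dict, text_conditions: dict) -> None:
    # Layout expected by PrecomputedDataset:
    #   <data_root>/precomputed/latents/<name>.pt
    #   <data_root>/precomputed/conditions/<name>.pt
    root = Path(data_root) / PRECOMPUTED_DIR_NAME
    latents_dir = root / PRECOMPUTED_LATENTS_DIR_NAME
    conditions_dir = root / PRECOMPUTED_CONDITIONS_DIR_NAME
    latents_dir.mkdir(parents=True, exist_ok=True)
    conditions_dir.mkdir(parents=True, exist_ok=True)
    # PrecomputedDataset pairs files by sorted order, so both directories
    # must use identical file names.
    torch.save(latent_conditions, latents_dir / f"{index:06d}.pt")
    torch.save(text_conditions, conditions_dir / f"{index:06d}.pt")
```

Since each item is a dict of CPU tensors loaded with `weights_only=True`, a `PrecomputedDataset(data_root)` can then be wrapped in a regular `torch.utils.data.DataLoader`.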
