Precomputation of conditions and latents #129

Merged 12 commits on Dec 23, 2024
119 changes: 116 additions & 3 deletions README.md
@@ -143,6 +143,60 @@ video = pipe("<my-awesome-prompt>").frames[0]
export_to_video(video, "output.mp4", fps=8)
```

### Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, the AdamW optimizer, and `49x512x768` resolution, **without precomputation**:

```
Training configuration: {
"trainable parameters": 117440512,
"total samples": 69,
"train epochs": 1,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 69,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                    | memory_allocated (GB) | max_memory_reserved (GB) |
|:------------------------:|:---------------------:|:------------------------:|
| before training start | 13.486 | 13.879 |
| before validation start | 14.146 | 17.623 |
| after validation end | 14.146 | 17.623 |
| after epoch 1 | 14.146 | 17.623 |
| after training end | 4.461 | 17.623 |

Note: requires about `18` GB of VRAM without precomputation.

LoRA with rank 128, batch size 1, gradient checkpointing, the AdamW optimizer, and `49x512x768` resolution, **with precomputation**:

```
Training configuration: {
"trainable parameters": 117440512,
"total samples": 1,
"train epochs": 10,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 1,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                          | memory_allocated (GB) | max_memory_reserved (GB) |
|:------------------------------:|:---------------------:|:------------------------:|
| after precomputing conditions | 8.88 | 8.920 |
| after precomputing latents | 9.684 | 11.613 |
| before training start | 3.809 | 10.010 |
| after epoch 1 | 4.26 | 10.916 |
| before validation start | 4.26 | 10.916 |
| after validation end | 13.924 | 17.262 |
| after training end | 4.26 | 14.314 |

Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB.

</details>

<details>
@@ -169,8 +223,7 @@ OUTPUT_DIR="/path/to/models/hunyuan-video/hunyuan-video-loras/hunyuan-video_caki

# Model arguments
 model_cmd="--model_name hunyuan_video \
-  --pretrained_model_name_or_path tencent/HunyuanVideo
-  --revision refs/pr/18"
+  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
@@ -252,7 +305,7 @@ import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
@@ -272,10 +325,70 @@ output = pipe(
export_to_video(output, "output.mp4", fps=15)
```

### Memory Usage

LoRA with rank 128, batch size 1, gradient checkpointing, the AdamW optimizer, and `49x512x768` resolution, **without precomputation**:

```
Training configuration: {
"trainable parameters": 163577856,
"total samples": 69,
"train epochs": 1,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 69,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                    | memory_allocated (GB) | max_memory_reserved (GB) |
|:------------------------:|:---------------------:|:------------------------:|
| before training start | 38.889 | 39.020 |
| before validation start | 39.747 | 56.266 |
| after validation end | 39.748 | 58.385 |
| after epoch 1 | 39.748 | 40.910 |
| after training end | 25.288 | 40.910 |

Note: requires about `59` GB of VRAM without precomputation.

LoRA with rank 128, batch size 1, gradient checkpointing, the AdamW optimizer, and `49x512x768` resolution, **with precomputation**:

```
Training configuration: {
"trainable parameters": 163577856,
"total samples": 1,
"train epochs": 10,
"train steps": 10,
"batches per device": 1,
"total batches observed per epoch": 1,
"train batch size": 1,
"gradient accumulation steps": 1
}
```

| stage                          | memory_allocated (GB) | max_memory_reserved (GB) |
|:------------------------------:|:---------------------:|:------------------------:|
| after precomputing conditions | 14.232 | 14.461 |
| after precomputing latents | 14.717 | 17.244 |
| before training start | 24.195 | 26.039 |
| after epoch 1 | 24.83 | 42.387 |
| before validation start | 24.842 | 42.387 |
| after validation end | 39.558 | 46.947 |
| after training end | 24.842 | 41.039 |

Note: requires about `47` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to about `42` GB.

</details>

If you would like to use a custom dataset, refer to the dataset preparation guide [here](./assets/dataset.md).

> [!NOTE]
> To lower memory requirements:
> - Pass `--precompute_conditions` when launching training.
> - Pass `--gradient_checkpointing` when launching training.
> - Do not perform validation/testing. This saves a significant amount of memory, which is especially helpful when training on GPUs with less VRAM.

## Memory requirements

<table align="center">
39 changes: 39 additions & 0 deletions finetrainers/args.py
@@ -1,6 +1,8 @@
import argparse
from typing import Any, Dict, List, Optional, Tuple

import torch

from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS


@@ -20,6 +22,11 @@ class Args:
revision: Optional[str] = None
variant: Optional[str] = None
cache_dir: Optional[str] = None
text_encoder_dtype: torch.dtype = torch.bfloat16
text_encoder_2_dtype: torch.dtype = torch.bfloat16
text_encoder_3_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16

# Dataset arguments
data_root: str = None
@@ -32,6 +39,7 @@ class Args:
video_reshape_mode: Optional[str] = None
caption_dropout_p: float = 0.00
caption_dropout_technique: str = "empty"
precompute_conditions: bool = False

# Dataloader arguments
dataloader_num_workers: int = 0
@@ -113,6 +121,11 @@ def to_dict(self) -> Dict[str, Any]:
"revision": self.revision,
"variant": self.variant,
"cache_dir": self.cache_dir,
"text_encoder_dtype": self.text_encoder_dtype,
"text_encoder_2_dtype": self.text_encoder_2_dtype,
"text_encoder_3_dtype": self.text_encoder_3_dtype,
"transformer_dtype": self.transformer_dtype,
"vae_dtype": self.vae_dtype,
},
"dataset_arguments": {
"data_root": self.data_root,
Expand All @@ -124,6 +137,8 @@ def to_dict(self) -> Dict[str, Any]:
"video_resolution_buckets": self.video_resolution_buckets,
"video_reshape_mode": self.video_reshape_mode,
"caption_dropout_p": self.caption_dropout_p,
"caption_dropout_technique": self.caption_dropout_technique,
"precompute_conditions": self.precompute_conditions,
},
"dataloader_arguments": {
"dataloader_num_workers": self.dataloader_num_workers,
@@ -234,6 +249,11 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")


def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
@@ -317,6 +337,11 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
choices=["empty", "zero"],
help="Technique to use for caption dropout.",
)
parser.add_argument(
"--precompute_conditions",
action="store_true",
help="Whether or not to precompute the conditionings for the model.",
)


def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
@@ -645,6 +670,13 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
)


_DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}


def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args = Args()

@@ -654,6 +686,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.revision = args.revision
result_args.variant = args.variant
result_args.cache_dir = args.cache_dir
result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype]
result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype]

# Dataset arguments
if args.data_root is None and args.dataset_file is None:
@@ -668,6 +705,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
result_args.video_reshape_mode = args.video_reshape_mode
result_args.caption_dropout_p = args.caption_dropout_p
result_args.caption_dropout_technique = args.caption_dropout_technique
result_args.precompute_conditions = args.precompute_conditions

# Dataloader arguments
result_args.dataloader_num_workers = args.dataloader_num_workers
4 changes: 4 additions & 0 deletions finetrainers/constants.py
@@ -19,6 +19,10 @@

FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")

PRECOMPUTED_DIR_NAME = "precomputed"
PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions"
PRECOMPUTED_LATENTS_DIR_NAME = "latents"

MODEL_DESCRIPTION = r"""
\# {model_id} {training_type} finetune

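For orientation, here is a minimal sketch of the on-disk layout these constants describe, which the precomputation step is expected to populate. The `data_root` value is hypothetical:

```python
from pathlib import Path

from finetrainers.constants import (
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)

data_root = Path("/path/to/dataset")  # hypothetical dataset root

# <data_root>/precomputed/conditions -> serialized text-encoder outputs
conditions_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
# <data_root>/precomputed/latents    -> serialized VAE latents
latents_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
```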
30 changes: 30 additions & 0 deletions finetrainers/dataset.py
@@ -1,3 +1,4 @@
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -19,6 +20,9 @@

decord.bridge.set_bridge("torch")

from .constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME


logger = get_logger(__name__)


@@ -257,6 +261,32 @@ def _find_nearest_resolution(self, height, width):
return nearest_res[1], nearest_res[2]


class PrecomputedDataset(Dataset):
def __init__(self, data_root: str) -> None:
super().__init__()

self.data_root = Path(data_root)

self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME

self.latent_conditions = sorted(os.listdir(self.latents_path))
self.text_conditions = sorted(os.listdir(self.conditions_path))

assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"

def __len__(self) -> int:
return len(self.latent_conditions)

def __getitem__(self, index: int) -> Dict[str, Any]:
conditions = {}
latent_path = self.latents_path / self.latent_conditions[index]
condition_path = self.conditions_path / self.text_conditions[index]
conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
return conditions


class BucketSampler(Sampler):
r"""
PyTorch Sampler that groups 3D data by height, width and frames.
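To make the intent of the new `PrecomputedDataset` concrete, here is a minimal usage sketch. The `data_root` path is hypothetical and the pass-through collate function is an illustrative choice rather than the trainer's actual wiring; it assumes conditions and latents have already been written under `<data_root>/precomputed/`:

```python
from torch.utils.data import DataLoader

from finetrainers.dataset import PrecomputedDataset

# Expects <data_root>/precomputed/{conditions,latents} produced by the precomputation step.
dataset = PrecomputedDataset(data_root="/path/to/dataset")

# batch_size=1 with a pass-through collate keeps each sample's dict structure intact.
dataloader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])

for sample in dataloader:
    latent_conditions = sample["latent_conditions"]  # loaded via torch.load (typically a dict of tensors)
    text_conditions = sample["text_conditions"]      # loaded via torch.load (typically a dict of tensors)
    # ... run the training step on the precomputed tensors ...
    break
```

Because everything the training step needs is serialized ahead of time, the text encoders and VAE do not have to stay resident during training, which is consistent with the lower "before training start" memory reported in the README tables above.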