diff --git a/README.md b/README.md index f1eee8c..60dfc7f 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,60 @@ video = pipe("").frames[0] export_to_video(video, "output.mp4", fps=8) ``` +### Memory Usage + +LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **without precomputation**: + +``` +Training configuration: { + "trainable parameters": 117440512, + "total samples": 69, + "train epochs": 1, + "train steps": 10, + "batches per device": 1, + "total batches observed per epoch": 69, + "train batch size": 1, + "gradient accumulation steps": 1 +} +``` + +| stage | memory_allocated | max_memory_reserved | +|:-----------------------:|:----------------:|:-------------------:| +| before training start | 13.486 | 13.879 | +| before validation start | 14.146 | 17.623 | +| after validation end | 14.146 | 17.623 | +| after epoch 1 | 14.146 | 17.623 | +| after training end | 4.461 | 17.623 | + +Note: requires about `18` GB of VRAM without precomputation. + +LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolution, **with precomputation**: + +``` +Training configuration: { + "trainable parameters": 117440512, + "total samples": 1, + "train epochs": 10, + "train steps": 10, + "batches per device": 1, + "total batches observed per epoch": 1, + "train batch size": 1, + "gradient accumulation steps": 1 +} +``` + +| stage | memory_allocated | max_memory_reserved | +|:-----------------------------:|:----------------:|:-------------------:| +| after precomputing conditions | 8.88 | 8.920 | +| after precomputing latents | 9.684 | 11.613 | +| before training start | 3.809 | 10.010 | +| after epoch 1 | 4.26 | 10.916 | +| before validation start | 4.26 | 10.916 | +| after validation end | 13.924 | 17.262 | +| after training end | 4.26 | 14.314 | + +Note: requires about `17.5` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to `11` GB. +
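+The stage-wise figures above are in GB. They can be reproduced with PyTorch's CUDA memory
+APIs; the snippet below is only an illustrative sketch, not the exact `get_memory_statistics`
+helper used by the trainer:
+
+```python
+import torch
+
+GiB = 1024**3
+
+def log_memory(stage: str) -> None:
+    # Tensors currently resident on the GPU vs. the peak amount of memory the CUDA
+    # caching allocator has reserved so far, both converted to GiB.
+    allocated = torch.cuda.memory_allocated() / GiB
+    max_reserved = torch.cuda.max_memory_reserved() / GiB
+    print(f"{stage}: memory_allocated={allocated:.3f}, max_memory_reserved={max_reserved:.3f}")
+
+log_memory("before training start")
+# ... run training / validation here ...
+log_memory("after validation end")
+torch.cuda.reset_peak_memory_stats()  # start a fresh peak measurement for the next stage
+```
+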
@@ -169,8 +223,7 @@ OUTPUT_DIR="/path/to/models/hunyuan-video/hunyuan-video-loras/hunyuan-video_caki # Model arguments model_cmd="--model_name hunyuan_video \ - --pretrained_model_name_or_path tencent/HunyuanVideo - --revision refs/pr/18" + --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo" # Dataset arguments dataset_cmd="--data_root $DATA_ROOT \ @@ -252,7 +305,7 @@ import torch from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel from diffusers.utils import export_to_video -model_id = "tencent/HunyuanVideo" +model_id = "hunyuanvideo-community/HunyuanVideo" transformer = HunyuanVideoTransformer3DModel.from_pretrained( model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ) @@ -272,10 +325,70 @@ output = pipe( export_to_video(output, "output.mp4", fps=15) ``` +### Memory Usage + +LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **without precomputation**: + +``` +Training configuration: { + "trainable parameters": 163577856, + "total samples": 69, + "train epochs": 1, + "train steps": 10, + "batches per device": 1, + "total batches observed per epoch": 69, + "train batch size": 1, + "gradient accumulation steps": 1 +} +``` + +| stage | memory_allocated | max_memory_reserved | +|:-----------------------:|:----------------:|:-------------------:| +| before training start | 38.889 | 39.020 | +| before validation start | 39.747 | 56.266 | +| after validation end | 39.748 | 58.385 | +| after epoch 1 | 39.748 | 40.910 | +| after training end | 25.288 | 40.910 | + +Note: requires about `59` GB of VRAM without precomputation. + +LoRA with rank 128, batch size 1, gradient checkpointing, optimizer adamw, `49x512x768` resolutions, **with precomputation**: + +``` +Training configuration: { + "trainable parameters": 163577856, + "total samples": 1, + "train epochs": 10, + "train steps": 10, + "batches per device": 1, + "total batches observed per epoch": 1, + "train batch size": 1, + "gradient accumulation steps": 1 +} +``` + +| stage | memory_allocated | max_memory_reserved | +|:-----------------------------:|:----------------:|:-------------------:| +| after precomputing conditions | 14.232 | 14.461 | +| after precomputing latents | 14.717 | 17.244 | +| before training start | 24.195 | 26.039 | +| after epoch 1 | 24.83 | 42.387 | +| before validation start | 24.842 | 42.387 | +| after validation end | 39.558 | 46.947 | +| after training end | 24.842 | 41.039 | + +Note: requires about `47` GB of VRAM with precomputation. If validation is not performed, the memory usage is reduced to about `42` GB. +
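+Precomputation lowers the peak because the model loader is split into three groups that are
+loaded and freed one after another, so the text encoders and VAE never have to share memory
+with the transformer. The following is a rough sketch of the idea, assuming the module path
+shown in this repository is importable as-is (the exact orchestration lives in
+`Trainer.prepare_precomputations`):
+
+```python
+import torch
+from finetrainers.hunyuan_video.hunyuan_video_lora import (
+    load_condition_models,
+    load_diffusion_models,
+    load_latent_models,
+)
+
+model_id = "hunyuanvideo-community/HunyuanVideo"
+
+# 1. Text encoders + tokenizers only: encode all prompts, save them to disk, then free.
+condition_models = load_condition_models(model_id=model_id, text_encoder_dtype=torch.bfloat16)
+# ... precompute and save text conditions ...
+del condition_models
+
+# 2. VAE only: encode all videos into latents, save them to disk, then free.
+latent_models = load_latent_models(model_id=model_id, vae_dtype=torch.bfloat16)
+# ... precompute and save latents ...
+del latent_models
+
+# 3. Transformer + scheduler: the only large components kept resident during training.
+diffusion_models = load_diffusion_models(model_id=model_id, transformer_dtype=torch.bfloat16)
+```
+
+During actual training each group is additionally moved to the accelerator device and released
+via `free_memory()` before the next group is loaded (see `_delete_components` in the trainer).
+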
If you would like to use a custom dataset, refer to the dataset preparation guide [here](./assets/dataset.md). +> [!NOTE] +> To lower memory requirements: +> - Pass `--precompute_conditions` when launching training. +> - Pass `--gradient_checkpointing` when launching training. +> - Do not perform validation/testing. This saves a significant amount of memory, which can be used to focus solely on training if you're on smaller VRAM GPUs. + ## Memory requirements diff --git a/finetrainers/args.py b/finetrainers/args.py index c172028..31c2a76 100644 --- a/finetrainers/args.py +++ b/finetrainers/args.py @@ -1,6 +1,8 @@ import argparse from typing import Any, Dict, List, Optional, Tuple +import torch + from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS @@ -20,6 +22,11 @@ class Args: revision: Optional[str] = None variant: Optional[str] = None cache_dir: Optional[str] = None + text_encoder_dtype: torch.dtype = torch.bfloat16 + text_encoder_2_dtype: torch.dtype = torch.bfloat16 + text_encoder_3_dtype: torch.dtype = torch.bfloat16 + transformer_dtype: torch.dtype = torch.bfloat16 + vae_dtype: torch.dtype = torch.bfloat16 # Dataset arguments data_root: str = None @@ -32,6 +39,7 @@ class Args: video_reshape_mode: Optional[str] = None caption_dropout_p: float = 0.00 caption_dropout_technique: str = "empty" + precompute_conditions: bool = False # Dataloader arguments dataloader_num_workers: int = 0 @@ -113,6 +121,11 @@ def to_dict(self) -> Dict[str, Any]: "revision": self.revision, "variant": self.variant, "cache_dir": self.cache_dir, + "text_encoder_dtype": self.text_encoder_dtype, + "text_encoder_2_dtype": self.text_encoder_2_dtype, + "text_encoder_3_dtype": self.text_encoder_3_dtype, + "transformer_dtype": self.transformer_dtype, + "vae_dtype": self.vae_dtype, }, "dataset_arguments": { "data_root": self.data_root, @@ -124,6 +137,8 @@ def to_dict(self) -> Dict[str, Any]: "video_resolution_buckets": self.video_resolution_buckets, "video_reshape_mode": self.video_reshape_mode, "caption_dropout_p": self.caption_dropout_p, + "caption_dropout_technique": self.caption_dropout_technique, + "precompute_conditions": self.precompute_conditions, }, "dataloader_arguments": { "dataloader_num_workers": self.dataloader_num_workers, @@ -234,6 +249,11 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None: default=None, help="The directory where the downloaded models and datasets will be stored.", ) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.") + parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.") + parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.") + parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.") + parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.") def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: @@ -317,6 +337,11 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int choices=["empty", "zero"], help="Technique to use for caption dropout.", ) + parser.add_argument( + "--precompute_conditions", + action="store_true", + help="Whether or not to precompute the conditionings for the model.", + ) def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None: @@ -645,6 +670,13 @@ def 
_add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None: ) +_DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, +} + + def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args = Args() @@ -654,6 +686,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.revision = args.revision result_args.variant = args.variant result_args.cache_dir = args.cache_dir + result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype] + result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype] + result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] + result_args.transformer_dtype = _DTYPE_MAP[args.transformer_dtype] + result_args.vae_dtype = _DTYPE_MAP[args.vae_dtype] # Dataset arguments if args.data_root is None and args.dataset_file is None: @@ -668,6 +705,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS result_args.video_reshape_mode = args.video_reshape_mode result_args.caption_dropout_p = args.caption_dropout_p + result_args.caption_dropout_technique = args.caption_dropout_technique + result_args.precompute_conditions = args.precompute_conditions # Dataloader arguments result_args.dataloader_num_workers = args.dataloader_num_workers diff --git a/finetrainers/constants.py b/finetrainers/constants.py index bc050f6..26d64ff 100644 --- a/finetrainers/constants.py +++ b/finetrainers/constants.py @@ -19,6 +19,10 @@ FINETRAINERS_LOG_LEVEL = os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO") +PRECOMPUTED_DIR_NAME = "precomputed" +PRECOMPUTED_CONDITIONS_DIR_NAME = "conditions" +PRECOMPUTED_LATENTS_DIR_NAME = "latents" + MODEL_DESCRIPTION = r""" \# {model_id} {training_type} finetune diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py index 8ca0bb6..53b9378 100644 --- a/finetrainers/dataset.py +++ b/finetrainers/dataset.py @@ -1,3 +1,4 @@ +import os import random from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -19,6 +20,9 @@ decord.bridge.set_bridge("torch") +from .constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME + + logger = get_logger(__name__) @@ -257,6 +261,32 @@ def _find_nearest_resolution(self, height, width): return nearest_res[1], nearest_res[2] +class PrecomputedDataset(Dataset): + def __init__(self, data_root: str) -> None: + super().__init__() + + self.data_root = Path(data_root) + + self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME + self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME + + self.latent_conditions = sorted(os.listdir(self.latents_path)) + self.text_conditions = sorted(os.listdir(self.conditions_path)) + + assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match" + + def __len__(self) -> int: + return len(self.latent_conditions) + + def __getitem__(self, index: int) -> Dict[str, Any]: + conditions = {} + latent_path = self.latents_path / self.latent_conditions[index] + condition_path = self.conditions_path / self.text_conditions[index] + conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True) + conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True) + return conditions + + class BucketSampler(Sampler): r""" PyTorch Sampler that groups 3D data by height, 
width and frames. diff --git a/finetrainers/hunyuan_video/hunyuan_video_lora.py b/finetrainers/hunyuan_video/hunyuan_video_lora.py index 609f6fc..ec623d5 100644 --- a/finetrainers/hunyuan_video/hunyuan_video_lora.py +++ b/finetrainers/hunyuan_video/hunyuan_video_lora.py @@ -16,14 +16,13 @@ logger = get_logger("finetrainers") # pylint: disable=invalid-name -def load_components( - model_id: str = "tencent/HunyuanVideo", +def load_condition_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", text_encoder_dtype: torch.dtype = torch.float16, text_encoder_2_dtype: torch.dtype = torch.float16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.float16, revision: Optional[str] = None, cache_dir: Optional[str] = None, + **kwargs, ) -> Dict[str, nn.Module]: tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) text_encoder = LlamaModel.from_pretrained( @@ -35,26 +34,43 @@ def load_components( text_encoder_2 = CLIPTextModel.from_pretrained( model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir ) - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) - vae = AutoencoderKLHunyuanVideo.from_pretrained( - model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir - ) - scheduler = FlowMatchEulerDiscreteScheduler() return { "tokenizer": tokenizer, "text_encoder": text_encoder, "tokenizer_2": tokenizer_2, "text_encoder_2": text_encoder_2, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, } +def load_latent_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + vae_dtype: torch.dtype = torch.float16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir + ) + return {"vae": vae} + + +def load_diffusion_models( + model_id: str = "hunyuanvideo-community/HunyuanVideo", + transformer_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir + ) + scheduler = FlowMatchEulerDiscreteScheduler() + return {"transformer": transformer, "scheduler": scheduler} + + def initialize_pipeline( - model_id: str = "tencent/HunyuanVideo", + model_id: str = "hunyuanvideo-community/HunyuanVideo", text_encoder_dtype: torch.dtype = torch.float16, text_encoder_2_dtype: torch.dtype = torch.float16, transformer_dtype: torch.dtype = torch.bfloat16, @@ -72,6 +88,7 @@ def initialize_pipeline( enable_slicing: bool = False, enable_tiling: bool = False, enable_model_cpu_offload: bool = False, + **kwargs, ) -> HunyuanVideoPipeline: component_name_pairs = [ ("tokenizer", tokenizer), @@ -115,7 +132,7 @@ def prepare_conditions( guidance: float = 1.0, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 128, + max_sequence_length: int = 256, # TODO(aryan): make configurable prompt_template: Dict[str, Any] = { "template": ( @@ -129,6 +146,7 @@ def prepare_conditions( ), 
"crop_start": 95, }, + **kwargs, ) -> torch.Tensor: device = device or text_encoder.device dtype = dtype or text_encoder.dtype @@ -154,6 +172,7 @@ def prepare_latents( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, generator: Optional[torch.Generator] = None, + precompute: bool = False, **kwargs, ) -> torch.Tensor: device = device or vae.device @@ -165,9 +184,24 @@ def prepare_latents( image_or_video = image_or_video.to(device=device, dtype=vae.dtype) image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - latents = latents * vae.config.scaling_factor - latents = latents.to(dtype=dtype) + if not precompute: + latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) + latents = latents * vae.config.scaling_factor + latents = latents.to(dtype=dtype) + return {"latents": latents} + else: + if vae.use_slicing and image_or_video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] + h = torch.cat(encoded_slices) + else: + h = vae._encode(image_or_video) + return {"latents": h} + + +def post_latent_preparation( + latents: torch.Tensor, + **kwargs, +) -> torch.Tensor: return {"latents": latents} @@ -309,10 +343,13 @@ def _get_clip_prompt_embeds( HUNYUAN_VIDEO_T2V_LORA_CONFIG = { "pipeline_cls": HunyuanVideoPipeline, - "load_components": load_components, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, "initialize_pipeline": initialize_pipeline, "prepare_conditions": prepare_conditions, "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, "collate_fn": collate_fn_t2v, "forward_pass": forward_pass, "validation": validation, diff --git a/finetrainers/ltx_video/ltx_video_lora.py b/finetrainers/ltx_video/ltx_video_lora.py index f8109f9..59a121f 100644 --- a/finetrainers/ltx_video/ltx_video_lora.py +++ b/finetrainers/ltx_video/ltx_video_lora.py @@ -12,32 +12,45 @@ logger = get_logger("finetrainers") # pylint: disable=invalid-name -def load_components( +def load_condition_models( model_id: str = "Lightricks/LTX-Video", text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, revision: Optional[str] = None, cache_dir: Optional[str] = None, + **kwargs, ) -> Dict[str, nn.Module]: tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) text_encoder = T5EncoderModel.from_pretrained( model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir ) - transformer = LTXVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + +def load_latent_models( + model_id: str = "Lightricks/LTX-Video", + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: vae = AutoencoderKLLTXVideo.from_pretrained( model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir ) + return {"vae": vae} + + +def load_diffusion_models( + model_id: str = "Lightricks/LTX-Video", + transformer_dtype: torch.dtype = torch.bfloat16, + 
revision: Optional[str] = None, + cache_dir: Optional[str] = None, + **kwargs, +) -> Dict[str, nn.Module]: + transformer = LTXVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir + ) scheduler = FlowMatchEulerDiscreteScheduler() - return { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "transformer": transformer, - "vae": vae, - "scheduler": scheduler, - } + return {"transformer": transformer, "scheduler": scheduler} def initialize_pipeline( @@ -114,19 +127,55 @@ def prepare_latents( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, generator: Optional[torch.Generator] = None, + precompute: bool = False, ) -> torch.Tensor: device = device or vae.device - dtype = dtype or vae.dtype if image_or_video.ndim == 4: image_or_video = image_or_video.unsqueeze(2) assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" - image_or_video = image_or_video.to(device=device, dtype=dtype) + image_or_video = image_or_video.to(device=device, dtype=vae.dtype) image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - _, _, num_frames, height, width = latents.shape - latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std) + if not precompute: + latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + _, _, num_frames, height, width = latents.shape + latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std) + latents = _pack_latents(latents, patch_size, patch_size_t) + return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} + else: + if vae.use_slicing and image_or_video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] + h = torch.cat(encoded_slices) + else: + h = vae._encode(image_or_video) + _, _, num_frames, height, width = h.shape + + # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file + # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored + # so as to reduce the disk memory requirements of the precomputed files. 
+ return { + "latents": h, + "num_frames": num_frames, + "height": height, + "width": width, + "latents_mean": vae.latents_mean, + "latents_std": vae.latents_std, + } + + +def post_latent_preparation( + latents: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, +) -> torch.Tensor: + latents = _normalize_latents(latents, latents_mean, latents_std) latents = _pack_latents(latents, patch_size, patch_size_t) return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} @@ -260,10 +309,13 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int LTX_VIDEO_T2V_LORA_CONFIG = { "pipeline_cls": LTXPipeline, - "load_components": load_components, + "load_condition_models": load_condition_models, + "load_latent_models": load_latent_models, + "load_diffusion_models": load_diffusion_models, "initialize_pipeline": initialize_pipeline, "prepare_conditions": prepare_conditions, "prepare_latents": prepare_latents, + "post_latent_preparation": post_latent_preparation, "collate_fn": collate_fn_t2v, "forward_pass": forward_pass, "validation": validation, diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py index 71e4d60..f98393c 100644 --- a/finetrainers/trainer.py +++ b/finetrainers/trainer.py @@ -29,20 +29,27 @@ compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, ) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution from diffusers.utils import export_to_video, load_image, load_video from huggingface_hub import create_repo, upload_folder from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from tqdm import tqdm from .args import Args, validate_args -from .constants import FINETRAINERS_LOG_LEVEL -from .dataset import BucketSampler, VideoDatasetWithResizing +from .constants import ( + FINETRAINERS_LOG_LEVEL, + PRECOMPUTED_DIR_NAME, + PRECOMPUTED_CONDITIONS_DIR_NAME, + PRECOMPUTED_LATENTS_DIR_NAME, +) +from .dataset import BucketSampler, PrecomputedDataset, VideoDatasetWithResizing from .models import get_config_from_model_name from .state import State +from .utils.data_utils import should_perform_precomputation from .utils.file_utils import find_files, delete_files, string_to_filename from .utils.optimizer_utils import get_optimizer, gradient_norm from .utils.memory_utils import get_memory_statistics, free_memory, make_contiguous -from .utils.torch_utils import unwrap_model +from .utils.torch_utils import unwrap_model, align_device_and_dtype logger = get_logger("finetrainers") @@ -73,6 +80,9 @@ def __init__(self, args: Args) -> None: # Autoencoders self.vae = None + # Scheduler + self.scheduler = None + self._init_distributed() self._init_logging() self._init_directories_and_repositories() @@ -80,28 +90,82 @@ def __init__(self, args: Args) -> None: self.state.model_name = self.args.model_name self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type) - def prepare_models(self) -> None: - logger.info("Initializing models") + def prepare_dataset(self) -> None: + # TODO(aryan): Make a background process for fetching + logger.info("Initializing dataset and dataloader") - # TODO(aryan): refactor in future - load_components_kwargs = { - "text_encoder_dtype": torch.bfloat16, - "transformer_dtype": torch.bfloat16, - "vae_dtype": torch.bfloat16, + self.dataset = VideoDatasetWithResizing( + data_root=self.args.data_root, + 
caption_column=self.args.caption_column, + video_column=self.args.video_column, + resolution_buckets=self.args.video_resolution_buckets, + dataset_file=self.args.dataset_file, + id_token=self.args.id_token, + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=1, + sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), + collate_fn=self.model_config.get("collate_fn"), + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.pin_memory, + ) + + def _get_load_components_kwargs(self) -> Dict[str, Any]: + load_component_kwargs = { + "text_encoder_dtype": self.args.text_encoder_dtype, + "text_encoder_2_dtype": self.args.text_encoder_2_dtype, + "text_encoder_3_dtype": self.args.text_encoder_3_dtype, + "transformer_dtype": self.args.transformer_dtype, + "vae_dtype": self.args.vae_dtype, "revision": self.args.revision, "cache_dir": self.args.cache_dir, } if self.args.pretrained_model_name_or_path is not None: - load_components_kwargs["model_id"] = self.args.pretrained_model_name_or_path - components = self._model_config_call(self.model_config["load_components"], load_components_kwargs) + load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path + return load_component_kwargs + + def _set_components(self, components: Dict[str, Any]) -> None: + self.tokenizer = components.get("tokenizer", self.tokenizer) + self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2) + self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3) + self.text_encoder = components.get("text_encoder", self.text_encoder) + self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2) + self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3) + self.transformer = components.get("transformer", self.transformer) + self.unet = components.get("unet", self.unet) + self.vae = components.get("vae", self.vae) + self.scheduler = components.get("scheduler", self.scheduler) + + def _delete_components(self) -> None: + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + self.transformer = None + self.unet = None + self.vae = None + self.scheduler = None + free_memory() + torch.cuda.synchronize(self.state.accelerator.device) + + def prepare_models(self) -> None: + logger.info("Initializing models") - self.tokenizer = components.get("tokenizer", None) - self.text_encoder = components.get("text_encoder", None) - self.tokenizer_2 = components.get("tokenizer_2", None) - self.text_encoder_2 = components.get("text_encoder_2", None) - self.transformer = components.get("transformer", None) - self.vae = components.get("vae", None) - self.scheduler = components.get("scheduler", None) + load_components_kwargs = self._get_load_components_kwargs() + condition_components, latent_components, diffusion_components = {}, {}, {} + if not self.args.precompute_conditions: + condition_components = self.model_config["load_condition_models"](**load_components_kwargs) + latent_components = self.model_config["load_latent_models"](**load_components_kwargs) + diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs) + + components = {} + components.update(condition_components) + components.update(latent_components) + components.update(diffusion_components) + self._set_components(components) if self.vae is not None: if self.args.enable_slicing: @@ -111,22 +175,168 @@ def prepare_models(self) 
-> None: self.transformer_config = self.transformer.config if self.transformer is not None else None - def prepare_dataset(self) -> None: - logger.info("Initializing dataset and dataloader") + def prepare_precomputations(self) -> None: + if not self.args.precompute_conditions: + return - self.dataset = VideoDatasetWithResizing( - data_root=self.args.data_root, - caption_column=self.args.caption_column, - video_column=self.args.video_column, - resolution_buckets=self.args.video_resolution_buckets, - dataset_file=self.args.dataset_file, - id_token=self.args.id_token, + logger.info("Initializing precomputations") + + if self.args.batch_size != 1: + raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.") + + def collate_fn(batch): + latent_conditions = [x["latent_conditions"] for x in batch] + text_conditions = [x["text_conditions"] for x in batch] + batched_latent_conditions = {} + batched_text_conditions = {} + for key in list(latent_conditions[0].keys()): + if torch.is_tensor(latent_conditions[0][key]): + batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0) + else: + # TODO(aryan): implement batch sampler for precomputed latents + batched_latent_conditions[key] = [x[key] for x in latent_conditions][0] + for key in list(text_conditions[0].keys()): + if torch.is_tensor(text_conditions[0][key]): + batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0) + else: + # TODO(aryan): implement batch sampler for precomputed latents + batched_text_conditions[key] = [x[key] for x in text_conditions][0] + return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions} + + should_precompute = should_perform_precomputation(self.args.data_root) + if not should_precompute: + logger.info("Precomputed conditions and latents found. Loading precomputed data.") + self.dataloader = torch.utils.data.DataLoader( + PrecomputedDataset(self.args.data_root), + batch_size=self.args.batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.pin_memory, + ) + return + + logger.info("Precomputed conditions and latents not found. Running precomputation.") + + # At this point, no models are loaded, so we need to load and precompute conditions and latents + condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs()) + self._set_components(condition_components) + self._move_components_to_device() + + # TODO(aryan): refactor later. for now only lora is supported + components_to_disable_grads = [ + self.text_encoder, + self.text_encoder_2, + self.text_encoder_3, + ] + for component in components_to_disable_grads: + if component is not None: + component.requires_grad_(False) + + if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty": + logger.warning( + "Caption dropout is not supported with precomputation yet. This will be supported in the future." 
+ ) + + conditions_dir = Path(self.args.data_root) / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = Path(self.args.data_root) / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME + conditions_dir.mkdir(parents=True, exist_ok=True) + latents_dir.mkdir(parents=True, exist_ok=True) + + # Precompute conditions + progress_bar = tqdm( + range(0, len(self.dataset)), + desc="Precomputing conditions", + disable=not self.state.accelerator.is_local_main_process, ) + index = 0 + for i, data in enumerate(self.dataset): + if i % self.state.accelerator.num_processes != self.state.accelerator.process_index: + continue + + logger.debug( + f"Precomputing conditions and latents for batch {i + 1}/{len(self.dataset)} on process {self.state.accelerator.process_index}" + ) + + text_conditions = self.model_config["prepare_conditions"]( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + text_encoder_3=self.text_encoder_3, + prompt=data["prompt"], + device=self.state.accelerator.device, + dtype=self.state.weight_dtype, + ) + filename = conditions_dir / f"conditions-{i}-{index}.pt" + torch.save(text_conditions, filename.as_posix()) + index += 1 + progress_bar.update(1) + self._delete_components() + + memory_statistics = get_memory_statistics() + logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(self.state.accelerator.device) + + # Precompute latents + latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs()) + self._set_components(latent_components) + self._move_components_to_device() + + # TODO(aryan): refactor later + components_to_disable_grads = [self.vae] + for component in components_to_disable_grads: + if component is not None: + component.requires_grad_(False) + + if self.vae is not None: + if self.args.enable_slicing: + self.vae.enable_slicing() + if self.args.enable_tiling: + self.vae.enable_tiling() + + progress_bar = tqdm( + range(0, len(self.dataset)), + desc="Precomputing latents", + disable=not self.state.accelerator.is_local_main_process, + ) + index = 0 + for i, data in enumerate(self.dataset): + if i % self.state.accelerator.num_processes != self.state.accelerator.process_index: + continue + + logger.debug( + f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {self.state.accelerator.process_index}" + ) + + latent_conditions = self.model_config["prepare_latents"]( + vae=self.vae, + image_or_video=data["video"].unsqueeze(0), + device=self.state.accelerator.device, + dtype=self.state.weight_dtype, + generator=self.state.generator, + precompute=True, + ) + filename = latents_dir / f"latents-{self.state.accelerator.process_index}-{index}.pt" + torch.save(latent_conditions, filename.as_posix()) + index += 1 + progress_bar.update(1) + self._delete_components() + + self.state.accelerator.wait_for_everyone() + logger.info("Precomputation complete") + + memory_statistics = get_memory_statistics() + logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(self.state.accelerator.device) + + # Update dataloader to use precomputed conditions and latents self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=1, - sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), - 
collate_fn=self.model_config.get("collate_fn"), + PrecomputedDataset(self.args.data_root), + batch_size=self.args.batch_size, + shuffle=True, + collate_fn=collate_fn, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory, ) @@ -134,13 +344,20 @@ def prepare_dataset(self) -> None: def prepare_trainable_parameters(self) -> None: logger.info("Initializing trainable parameters") - # TODO(aryan): refactor later. for now only lora is supported - self.text_encoder.requires_grad_(False) - self.transformer.requires_grad_(False) - self.vae.requires_grad_(False) + diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs()) + self._set_components(diffusion_components) - if self.text_encoder_2 is not None: - self.text_encoder_2.requires_grad_(False) + # TODO(aryan): refactor later. for now only lora is supported + components_to_disable_grads = [ + self.text_encoder, + self.text_encoder_2, + self.text_encoder_3, + self.transformer, + self.vae, + ] + for component in components_to_disable_grads: + if component is not None: + component.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -158,12 +375,8 @@ def prepare_trainable_parameters(self) -> None: # TODO(aryan): handle torch dtype from accelerator vs model dtype; refactor self.state.weight_dtype = weight_dtype - self.text_encoder.to(self.state.accelerator.device, dtype=weight_dtype) - self.transformer.to(self.state.accelerator.device, dtype=weight_dtype) - self.vae.to(self.state.accelerator.device, dtype=weight_dtype) - - if self.text_encoder_2 is not None: - self.text_encoder_2.to(self.state.accelerator.device, dtype=weight_dtype) + self.transformer.to(dtype=weight_dtype) + self._move_components_to_device() if self.args.gradient_checkpointing: self.transformer.enable_gradient_checkpointing() @@ -364,44 +577,56 @@ def train(self) -> None: logs = {} with accelerator.accumulate(models_to_accumulate): - videos = batch["videos"] - prompts = batch["prompts"] - batch_size = len(prompts) - - if self.args.caption_dropout_technique == "empty": - if random.random() < self.args.caption_dropout_p: - prompts = [""] * batch_size + if not self.args.precompute_conditions: + videos = batch["videos"] + prompts = batch["prompts"] + batch_size = len(prompts) + + if self.args.caption_dropout_technique == "empty": + if random.random() < self.args.caption_dropout_p: + prompts = [""] * batch_size + + latent_conditions = self.model_config["prepare_latents"]( + vae=self.vae, + image_or_video=videos, + patch_size=self.transformer_config.patch_size, + patch_size_t=self.transformer_config.patch_size_t, + device=accelerator.device, + dtype=weight_dtype, + generator=generator, + ) + text_conditions = self.model_config["prepare_conditions"]( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + tokenizer_2=self.tokenizer_2, + text_encoder_2=self.text_encoder_2, + prompt=prompts, + device=accelerator.device, + dtype=weight_dtype, + ) + else: + latent_conditions = batch["latent_conditions"] + text_conditions = batch["text_conditions"] + latent_conditions["latents"] = DiagonalGaussianDistribution( + latent_conditions["latents"] + ).sample(generator) + if "post_latent_preparation" in self.model_config.keys(): + latent_conditions = self.model_config["post_latent_preparation"](**latent_conditions) + 
align_device_and_dtype(latent_conditions, accelerator.device, weight_dtype) + align_device_and_dtype(text_conditions, accelerator.device, weight_dtype) + batch_size = latent_conditions["latents"].shape[0] - latent_conditions = self.model_config["prepare_latents"]( - vae=self.vae, - image_or_video=videos, - patch_size=self.transformer_config.patch_size, - patch_size_t=self.transformer_config.patch_size_t, - device=accelerator.device, - dtype=weight_dtype, - generator=generator, - ) latent_conditions = make_contiguous(latent_conditions) - - other_conditions = self.model_config["prepare_conditions"]( - tokenizer=self.tokenizer, - text_encoder=self.text_encoder, - tokenizer_2=self.tokenizer_2, - text_encoder_2=self.text_encoder_2, - prompt=prompts, - device=accelerator.device, - dtype=weight_dtype, - ) - other_conditions = make_contiguous(other_conditions) + text_conditions = make_contiguous(text_conditions) if self.args.caption_dropout_technique == "zero": if random.random() < self.args.caption_dropout_p: - other_conditions["prompt_embeds"].fill_(0) - other_conditions["prompt_attention_mask"].fill_(False) + text_conditions["prompt_embeds"].fill_(0) + text_conditions["prompt_attention_mask"].fill_(False) # TODO(aryan): refactor later - if "pooled_prompt_embeds" in other_conditions: - other_conditions["pooled_prompt_embeds"].fill_(0) + if "pooled_prompt_embeds" in text_conditions: + text_conditions["pooled_prompt_embeds"].fill_(0) # These weighting schemes use a uniform timestep sampling and instead post-weight the loss weights = compute_density_for_timestep_sampling( @@ -424,14 +649,13 @@ def train(self) -> None: noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise latent_conditions.update({"noisy_latents": noisy_latents}) - other_conditions.update({"timesteps": timesteps}) # These weighting schemes use a uniform timestep sampling and instead post-weight the loss weights = compute_loss_weighting_for_sd3( weighting_scheme=self.args.flow_weighting_scheme, sigmas=sigmas ) pred = self.model_config["forward_pass"]( - transformer=self.transformer, **latent_conditions, **other_conditions + transformer=self.transformer, timesteps=timesteps, **latent_conditions, **text_conditions ) target = noise - latent_conditions["latents"] @@ -443,7 +667,8 @@ def train(self) -> None: accelerator.backward(loss) if accelerator.sync_gradients and accelerator.distributed_type != DistributedType.DEEPSPEED: - accelerator.clip_grad_norm_(self.transformer.parameters(), self.args.max_grad_norm) + grad_norm = accelerator.clip_grad_norm_(self.transformer.parameters(), self.args.max_grad_norm) + logs["grad_norm"] = grad_norm self.optimizer.step() self.lr_scheduler.step() @@ -588,6 +813,10 @@ def validate(self, step: int) -> None: generator=self.state.generator, ) + # Remove all hooks that might have been added during pipeline initialization to the models + pipeline.remove_all_hooks() + del pipeline + prompt_filename = string_to_filename(prompt)[:25] artifacts = { "image": {"type": "image", "value": image}, @@ -629,9 +858,12 @@ def validate(self, step: int) -> None: tracker.log({"validation": all_artifacts}, step=step) accelerator.wait_for_everyone() + free_memory() memory_statistics = get_memory_statistics() logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(accelerator.device) + self.transformer.train() def evaluate(self) -> None: @@ -692,7 +924,16 @@ def _init_directories_and_repositories(self) -> None: repo_id = 
self.args.hub_model_id or Path(self.args.output_dir).name self.state.repo_id = create_repo(token=self.args.hub_token, name=repo_id).repo_id - def _model_config_call(self, fn, kwargs): - accepted_kwargs = inspect.signature(fn).parameters.keys() - kwargs = {k: v for k, v in kwargs.items() if k in accepted_kwargs} - return fn(**kwargs) + def _move_components_to_device(self): + if self.text_encoder is not None: + self.text_encoder = self.text_encoder.to(self.state.accelerator.device) + if self.text_encoder_2 is not None: + self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device) + if self.text_encoder_3 is not None: + self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device) + if self.transformer is not None: + self.transformer = self.transformer.to(self.state.accelerator.device) + if self.unet is not None: + self.unet = self.unet.to(self.state.accelerator.device) + if self.vae is not None: + self.vae = self.vae.to(self.state.accelerator.device) diff --git a/finetrainers/utils/data_utils.py b/finetrainers/utils/data_utils.py new file mode 100644 index 0000000..b05c49e --- /dev/null +++ b/finetrainers/utils/data_utils.py @@ -0,0 +1,35 @@ +from pathlib import Path +from typing import Union + +from accelerate.logging import get_logger + +from ..constants import PRECOMPUTED_DIR_NAME, PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME + + +logger = get_logger("finetrainers") + + +def should_perform_precomputation(data_root: Union[str, Path]) -> bool: + if isinstance(data_root, str): + data_root = Path(data_root) + conditions_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME + latents_dir = data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME + if conditions_dir.exists() and latents_dir.exists(): + num_files_conditions = len(list(conditions_dir.glob("*.pt"))) + num_files_latents = len(list(latents_dir.glob("*.pt"))) + if num_files_conditions != num_files_latents: + logger.warning( + f"Number of precomputed conditions ({num_files_conditions}) does not match number of precomputed latents ({num_files_latents})." + f"Cleaning up precomputed directories and re-running precomputation." + ) + # clean up precomputed directories + for file in conditions_dir.glob("*.pt"): + file.unlink() + for file in latents_dir.glob("*.pt"): + file.unlink() + return True + if num_files_conditions > 0: + logger.info(f"Found {num_files_conditions} precomputed conditions and latents.") + return False + logger.info("Precomputed data not found. 
Running precomputation.") + return True diff --git a/finetrainers/utils/torch_utils.py b/finetrainers/utils/torch_utils.py index 32190bb..2aad077 100644 --- a/finetrainers/utils/torch_utils.py +++ b/finetrainers/utils/torch_utils.py @@ -1,3 +1,6 @@ +from typing import Dict, Optional, Union + +import torch from accelerate import Accelerator from diffusers.utils.torch_utils import is_compiled_module @@ -6,3 +9,21 @@ def unwrap_model(accelerator: Accelerator, model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model + + +def align_device_and_dtype( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + if isinstance(x, torch.Tensor): + if device is not None: + x = x.to(device) + if dtype is not None: + x = x.to(dtype) + elif isinstance(x, dict): + if device is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + if dtype is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + return x diff --git a/train.py b/train.py index 0a45987..e12d47c 100644 --- a/train.py +++ b/train.py @@ -27,6 +27,7 @@ def main(): trainer.prepare_dataset() trainer.prepare_models() + trainer.prepare_precomputations() trainer.prepare_trainable_parameters() trainer.prepare_optimizer() trainer.prepare_for_training()
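
With the changes above, `train.py` calls `trainer.prepare_precomputations()` before the trainable parameters are set up; when `--precompute_conditions` is passed, the text-encoder and VAE outputs are written to disk as `.pt` files before the transformer is ever loaded. As a rough sketch (the data root below is a placeholder), the precomputed artifacts can be inspected like this:

```python
from pathlib import Path

import torch

data_root = Path("/path/to/dataset")  # placeholder: the same --data_root used for training
conditions_dir = data_root / "precomputed" / "conditions"
latents_dir = data_root / "precomputed" / "latents"

# One file per sample; each file stores a dict of tensors (plus a few scalars for the latents).
condition_file = sorted(conditions_dir.glob("*.pt"))[0]
latent_file = sorted(latents_dir.glob("*.pt"))[0]

text_conditions = torch.load(condition_file, map_location="cpu", weights_only=True)
latent_conditions = torch.load(latent_file, map_location="cpu", weights_only=True)

print(sorted(text_conditions.keys()))    # prompt embeddings and attention masks
print(sorted(latent_conditions.keys()))  # e.g. "latents", "num_frames", "height", "width"
```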