Factor out the Trainer class #130

Closed · wants to merge 4 commits
Conversation

@sayakpaul (Collaborator) commented on Dec 20, 2024

@a-r-r-o-w this is how I am envisioning the Trainer class.

To test:

export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL=DEBUG

GPU_IDS="0,1"

DATA_ROOT="/home/sayak/finetrainers/video-dataset-disney"
CAPTION_COLUMN="prompt.txt"
VIDEO_COLUMN="videos.txt"
OUTPUT_DIR="/raid/.cache/huggingface/sayak/ltx-video/ltxv_disney"

# Model arguments
model_cmd="--model_name ltx_video \
  --pretrained_model_name_or_path Lightricks/LTX-Video"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token BW_STYLE \
  --video_resolution_buckets 49x512x768 \
  --caption_dropout_p 0.05"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 4"

# Diffusion arguments
diffusion_cmd="--flow_resolution_shifting"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --mixed_precision bf16 \
  --batch_size 1 \
  --train_steps 1200 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 3e-5 \
  --lr_scheduler constant_with_warmup \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"afkx A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x512x768:::A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage@@@49x512x768\" \
  --num_validation_videos 1 \
  --validation_steps 100"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-ltxv \
  --output_dir $OUTPUT_DIR \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_2.yaml --gpu_ids $GPU_IDS train_ltx.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

Of course, this will break HunyuanVideo, so I am happy to:

  • first create another branch called dev and merge the current PR into that;
  • then open another PR for HunyuanVideo and merge that into dev.

After that, we should be safe to merge dev into main.

Training logs: https://wandb.ai/sayakpaul/finetrainers-ltxv/runs/qcqbqjxq

WDYT?

@sayakpaul requested a review from @a-r-r-o-w on December 20, 2024 12:07
@a-r-r-o-w (Owner) left a comment:

Thanks! Just some concerns and ideas for the first draft.

I believe we're going to have some merge conflicts with #129 and would prefer that it be merged first because of the significant memory savings from precomputation. Happy to help resolve any merge conflicts and proceed further here.

import os
import sys

base_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
@a-r-r-o-w (Owner) commented:

I would prefer not to do this. train.py should be the entry point and determine which trainer to use based on the provided arguments, so I would rather not have any changes to the import style.
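
For illustration, something along these lines, where the registry, module paths, and the run() method are placeholder names rather than this PR's actual API:

# Hypothetical sketch: train.py stays the single entry point and picks the
# trainer class from the parsed arguments.
from finetrainers.args import parse_args                  # assumed helper returning an Args instance
from finetrainers.trainers.ltx_video import LTXTrainer    # module path assumed

TRAINER_REGISTRY = {
    "ltx_video": LTXTrainer,
    # "hunyuan_video": HunyuanVideoTrainer,  # would be added by the follow-up PR
}

def main() -> None:
    args = parse_args()
    trainer_cls = TRAINER_REGISTRY.get(args.model_name)
    if trainer_cls is None:
        raise ValueError(f"Unsupported --model_name: {args.model_name!r}")
    trainer_cls(args).run()  # run() stands in for whatever the top-level training method is called

if __name__ == "__main__":
    main()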



class LTXTrainer(Trainer):
    def prepare_models(self):
@a-r-r-o-w (Owner) commented:

As shown in #129, we can't use this approach because then we won't be able to save memory from precomputation. Happy to help resolve the merge conflicts and modify the implementation here to reach the same memory levels as #129.
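
For reference, the memory savings in #129 come from precomputation: cache the conditions once and free the encoders before the training loop starts, presumably something like this (illustrative only, not the actual #129 code):

# Illustrative sketch: encode prompts and videos once, cache the results on CPU,
# then drop the text encoder and VAE so only the transformer occupies GPU memory
# during training. Batch keys and model interfaces here are assumptions.
import gc
import torch

@torch.no_grad()
def precompute_conditions(text_encoder, vae, dataloader, device):
    cached = []
    for batch in dataloader:
        prompt_embeds = text_encoder(batch["input_ids"].to(device))[0]
        latents = vae.encode(batch["videos"].to(device)).latent_dist.sample()
        cached.append({"prompt_embeds": prompt_embeds.cpu(), "latents": latents.cpu()})
    return cached

# After caching:
#   del text_encoder, vae
#   gc.collect(); torch.cuda.empty_cache()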

from tqdm import tqdm

from .args import Args, validate_args
@a-r-r-o-w (Owner) commented:

Let's not import it this way by changing base paths through the os and sys modules. I prefer the same approach as diffusers, with relative imports, but if there's a benefit to doing it this way, could you help me understand?
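
i.e. keep the package imports relative, as the hunk above already does, instead of prepending the repo root to sys.path:

# Preferred (diffusers-style), valid when this file lives inside the finetrainers package:
from .args import Args, validate_args

# Rather than:
#   import os, sys
#   sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
#   from args import Args, validate_args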

def calculate_loss_weights(self, **kwargs):
    raise NotImplementedError

def sort_out_checkpoint_to_resume_from(self, accelerator):
@a-r-r-o-w (Owner) commented:

I would maybe make this a helper in checkpoint_utils.py and pass in all the relevant arguments.
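
For example, something along these lines in checkpoint_utils.py, where the function name and exact signature are just suggestions:

# Hypothetical helper: resolve which checkpoint directory to resume from,
# mirroring the usual "latest" convention used in diffusers training scripts.
import os

def find_checkpoint_to_resume_from(resume_from_checkpoint, output_dir):
    if resume_from_checkpoint != "latest":
        return os.path.basename(resume_from_checkpoint)
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    return dirs[-1] if len(dirs) > 0 else None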

path = dirs[-1] if len(dirs) > 0 else None

if path is None:
    accelerator.print(
@a-r-r-o-w (Owner) commented:

Let's not use accelerator.print. logger.info should work, and even if this is moved to a different file, get_logger("finetrainers") will allow us to perform the logging.
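
i.e. something like the following, with a placeholder message:

# Works identically regardless of which file the helper ends up in, and only the
# main process logs by default with accelerate's multi-process logger.
from accelerate.logging import get_logger

logger = get_logger("finetrainers")

def log_missing_checkpoint(path):
    logger.info(f"Checkpoint '{path}' does not exist. Starting a new training run.")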

@sayakpaul (Collaborator, Author) replied:

It was already here before this PR; I did not add the print call.


return initial_global_step, global_step, first_epoch

def save_intermediate_checkpoint(self, step, accelerator):
@a-r-r-o-w (Owner) commented:

I would also move this into a utils file, making sure to pass all required arguments, and import it here.
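
For example, a free function along these lines (the names and the pruning behaviour implied by --checkpointing_limit are assumptions):

# Hypothetical utility: save an intermediate checkpoint and optionally prune
# older ones so at most `checkpointing_limit` checkpoint dirs are kept around.
import os
import shutil

def save_intermediate_checkpoint(accelerator, output_dir, step, checkpointing_limit=None):
    if checkpointing_limit is not None:
        checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        checkpoints = sorted(checkpoints, key=lambda d: int(d.split("-")[1]))
        num_to_remove = len(checkpoints) - checkpointing_limit + 1
        for checkpoint in checkpoints[:max(0, num_to_remove)]:
            shutil.rmtree(os.path.join(output_dir, checkpoint))
    save_path = os.path.join(output_dir, f"checkpoint-{step}")
    accelerator.save_state(save_path)
    return save_path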

accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

def save_final_checkpoint(self, accelerator):
@a-r-r-o-w (Owner) commented:

LoRA and full finetuning have different ways of handling this, so I would make utility functions here too and import them. For now, since we only have LoRA, this is okay to keep here as well.
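
Roughly this shape once both exist; the full-finetune branch and the exact serialization calls below are assumptions, not code from this PR:

# Hypothetical utility: final checkpoint handling split by training type.
from peft import get_peft_model_state_dict
from safetensors.torch import save_file

def save_final_checkpoint(accelerator, transformer, training_type, output_dir):
    transformer = accelerator.unwrap_model(transformer)
    if training_type == "lora":
        # Save only the LoRA adapter weights, moved to CPU for serialization.
        lora_state_dict = {
            k: v.detach().cpu().contiguous()
            for k, v in get_peft_model_state_dict(transformer).items()
        }
        save_file(lora_state_dict, f"{output_dir}/pytorch_lora_weights.safetensors")
    else:
        # Full finetuning would serialize the whole transformer instead.
        transformer.save_pretrained(output_dir)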

@a-r-r-o-w mentioned this pull request on Dec 22, 2024
@sayakpaul (Collaborator, Author) commented:

Okay, please proceed with merging #129 first. It would be nice to get a heads-up about PRs like #129 beforehand so we can meaningfully streamline efforts.
