
Precomputation of conditions and latents #129

Merged: 12 commits merged from condition-precomputation into main on Dec 23, 2024
Conversation

@a-r-r-o-w (Owner)

No description provided.

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review December 21, 2024 14:56
@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul December 21, 2024 14:57
@sayakpaul (Collaborator)

Actually, I think I have an idea which I would like to consider before this is merged. I am on the move currently, but I'll get the comments in by this evening.

@a-r-r-o-w (Owner, Author)

Sounds good, LMK when you're free!

(Review threads on README.md, finetrainers/ltx_video/ltx_video_lora.py, finetrainers/trainer.py, and finetrainers/utils/torch_utils.py - all resolved.)
@sayakpaul (Collaborator)

Not for this PR, but usually, for large-scale training projects, any kind of precomputation is treated as "data processing" and is separated from the trainer classes.

So, in this case, as a (power) user, I would prefer:

  • Functional APIs for each of the pipelines that perform the necessary precomputations.
  • Scripts under a tools directory that show how to use those APIs to serialize the precomputed objects. Then we should be ready to do training with the corresponding Trainer class. Since args will have an output_dir, it's easy to discover the serialized objects from precomputation.
  • This will allow us a greater degree of flexibility in how we configure and utilize accelerators based on the compute patterns of the precomputation stage, which can be different from training.

But for users to get started, the current way of handling things is fine. If we do implement the functional APIs, the trainer classes could implement a precompute_conditions() method (like now) that utilizes them.
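
Roughly, what I have in mind is the sketch below: a plain function per pipeline plus a tiny tools script that serializes its outputs. Everything here (the function, the CLI flags, the .pt layout, the encoder-only text model assumption) is illustrative, not the current finetrainers API.

# tools/precompute_conditions.py -- illustrative sketch, not the actual finetrainers layout.
import argparse
from pathlib import Path

import torch
from transformers import AutoModel, AutoTokenizer


def precompute_conditions(prompts, tokenizer, text_encoder, device="cuda"):
    """Encode every prompt once so the trainer never has to keep the text encoder in memory."""
    text_encoder = text_encoder.to(device).eval()
    embeddings = []
    with torch.no_grad():
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
            embeddings.append(text_encoder(**tokens).last_hidden_state.cpu())
    return embeddings


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts_file", type=Path, required=True)
    parser.add_argument("--text_encoder_id", type=str, required=True)  # assumes an encoder-only text model
    parser.add_argument("--output_dir", type=Path, required=True)
    args = parser.parse_args()

    prompts = args.prompts_file.read_text().splitlines()
    tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_id)
    text_encoder = AutoModel.from_pretrained(args.text_encoder_id)

    embeddings = precompute_conditions(prompts, tokenizer, text_encoder)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    for index, embedding in enumerate(embeddings):
        torch.save(embedding, args.output_dir / f"condition_{index:05d}.pt")
    # A trainer's precompute_conditions() could call the same function directly,
    # or skip the work entirely if --output_dir already holds the serialized tensors.


if __name__ == "__main__":
    main()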

I also don't like train.py as the main entrypoint for training models, given that each trainer would be a separate class implementing its own atomic modules. Refer to https://huggingface.co/docs/trl/example_overview as an example. For now, train.py is fine, though. In short, each trainer implements an example on its own to keep things very clear and friendly to refer to. Example: https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py.

@a-r-r-o-w (Owner, Author)

a-r-r-o-w commented Dec 23, 2024

any kind of precomputation is treated as "data processing" and is separated from trainer classes.

I respectfully disagree. I may not have your experience, but as far as our initial goals go, we wanted to make video model training more accessible, not target large-scale training. Adding extra steps and more instructions users need to follow to make things work is a big no-go IMO. Yes, we can eventually try to be more ambitious, but for now we should focus on streamlining things as much as possible, supporting all things LoRA (normal LoRA, control LoRA, turbo distillation LoRAs, etc.), and making sure our potato friends can run it.

From some feedback I gathered, people who currently use this project want two things:

  • Create a videos.txt and a prompts.txt file.
  • Click some fancy switches and knobs to tweak hyperparameters, and press a button to start finetuning.

Of course, this is not a UI, so the button part is not to be considered and is a battle for someone else to fight. But opt-in precomputation like we do here, driven by just two text files, looks ideal to me (for now). We can reconsider in the future, as you mention.
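
For reference, that two-file flow boils down to a line-aligned pairing; a minimal sketch (the assumption that videos.txt holds paths relative to the dataset root is mine):

from pathlib import Path


def load_pairs(data_root: str):
    # Pair line i of prompts.txt with line i of videos.txt.
    root = Path(data_root)
    prompts = (root / "prompts.txt").read_text().splitlines()
    videos = (root / "videos.txt").read_text().splitlines()
    assert len(prompts) == len(videos), "both files must have the same number of lines"
    # Assumes videos.txt stores paths relative to the dataset root.
    return [(prompt, root / video) for prompt, video in zip(prompts, videos)]


# e.g. pairs = load_pairs("./my-dataset")  ->  list of (caption, video path) tuples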

You and I know about these things and can work our way through all of this, but I want you to think like a general user with no programming knowledge, who is just experimenting from a creator background, has read some tutorials online, and just wants to use sensible defaults; they will find all of this to be nonsense. Just from the history of diffusion inference over the past two years, I believe there are a lot of lessons to be learned about what became most popular among users who actually use these kinds of tools daily.

Large-scale training is quite a ways away from what we have. Basic data parallelism like we have will never scale IMO, so that is what we should tackle first before designing anything else, if that's what we want to prioritize.

Functional APIs for each of the pipelines that perform the necessary precomputations.

Awesome, so this is where we are at currently. We just need to "API"-fy it a bit more so that custom methods/callbacks can be passed in, or so that what we have currently can be re-used more easily.

This will allow us a greater degree of flexibility in how we configure and utilize accelerators based on the compute patterns of the precomputation stage, which can be different from training.

Utilizing accelerators based on the compute patterns of the precomputation stage affects training...? I don't quite follow what you mean. Maybe we can hop in DMs and I can try to understand?

I also don't like train.py as the main entrypoint for training models, given that each trainer would be a separate class implementing its own atomic modules. Refer to https://huggingface.co/docs/trl/example_overview as an example. For now, train.py is fine, though. In short, each trainer implements an example on its own to keep things very clear and friendly to refer to. Example: https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py.

I would like to distinguish between two things here:

  • Algorithm-specific training (LoRA vs full finetuning vs control LoRA vs ControlNet vs distillation, etc.)
  • Model-specific training (LTXV vs Hunyuan vs Cog, etc.)

TRL implements what I would call "algorithm-specific" training. It works with different models because many things are common and transformers provides some great abstractions, but also because it is mostly just a single model without many moving components. The same does not apply to diffusion models; we have discussed time and again how many moving components are involved, and we will most definitely require model-specific implementations because different models process things differently.

But OTOH, most underlying training algorithms are independent of the modeling architecture itself. For example, with a decent abstraction like what we have now, LoRA training works for both LTX and Hunyuan without any LoRA-specific code interleaving with the single-file implementations we have for model loading, condition and latent computation, etc. I believe the same would be the case for full finetuning, distillation, and pretty much everything else (we can break away into a better abstraction when we get to it, but let's not prematurely design with that in mind because it will only slow us down).

I would prefer doing the following:

  • Do thing X.
  • If it works without modifying the current implementation too much (or at all), good. Proceed with it.
  • If it requires substantial modifications, perform the refactor.

With that in mind, we can have multiple entry points, but they should not be at the "model" level; they should instead be at the "algorithm" level. The model to be trained and all the underlying configuration management should happen within a single file such as train_lora.py, train_control_lora.py, train_guidance_distillation.py, etc.
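
To make that concrete, a rough sketch of what an algorithm-level entry point could look like (the module layout, LoRATrainer, and CLI flags below are all hypothetical):

# train_lora.py -- hypothetical algorithm-level entry point: one file per training
# algorithm, with the model selected purely by configuration.
import argparse
import importlib


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", choices=["ltx_video", "hunyuan_video"], required=True)
    args, _ = parser.parse_known_args()

    # Each model ships one module exposing the same small surface:
    # load_components(), prepare_conditions(), prepare_latents(), forward(), ...
    # (module path is illustrative)
    model_utils = importlib.import_module(f"finetrainers.{args.model_name}.lora")

    # Hypothetical algorithm-level trainer; it only talks to model_utils through
    # that shared surface, so adding a new model never touches trainer code.
    from finetrainers.trainer import LoRATrainer  # illustrative import

    trainer = LoRATrainer(model_utils=model_utils, args=args)
    trainer.precompute_conditions()
    trainer.train()


if __name__ == "__main__":
    main()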

These are my thoughts on how we should proceed to maximize our speed of integrating the new models and techniques we want to support. Premature refactoring might bring us abstractions that look good now, but might slow us down and cause further refactoring in the future.

LMK what you think so we can proceed accordingly. If you think we absolutely need a better design before moving forward, I'm happy to assist and will not make any further PRs until we get that sorted - but if not, I can take up CogVideoX next and maybe you could do Mochi?

I will address your concerns shortly and push the relevant changes. Thanks for the reviews!

@sayakpaul (Collaborator)

Utilizing accelerators based on the compute patterns of the precomputation stage affects training...? I don't quite follow what you mean. Maybe we can hop in DMs and I can try to understand?

I didn't say "affects". I am saying it gives us a chance to better utilize accelerators in the way we want to. Accelerator distribution for precomputation isn't necessarily the same as for training. Below is an example:

  • Use the first GPU for text embeddings.
  • Use the other two for latent computation.

(both happening in parallel)

This (the way accelerators are being utilized) is not what we do during training.
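
A minimal sketch of that kind of split, just to make it concrete (the device placement and the encode_prompt / encode_video stand-ins are illustrative; in practice this would come from how we configure accelerate):

from concurrent.futures import ThreadPoolExecutor

import torch


def precompute_in_parallel(prompts, videos, text_encoder, vae, encode_prompt, encode_video):
    # During precomputation only: GPU 0 encodes prompts while GPU 1 encodes videos
    # into latents, both at the same time. With three GPUs, a second VAE copy
    # could take half of the videos on cuda:2.
    text_encoder.to("cuda:0").eval()
    vae.to("cuda:1").eval()

    def run_text():
        with torch.no_grad():
            return [encode_prompt(text_encoder, prompt, device="cuda:0") for prompt in prompts]

    def run_latents():
        with torch.no_grad():
            return [encode_video(vae, video, device="cuda:1") for video in videos]

    with ThreadPoolExecutor(max_workers=2) as pool:
        text_future = pool.submit(run_text)
        latent_future = pool.submit(run_latents)
        return text_future.result(), latent_future.result()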

But it's okay not to proceed in the direction I mentioned. However, I would not prefer having a single Trainer class, if that is what you absolutely want. For every model, I envision a separate Trainer class because, maintenance-wise, it's better and easier to maintain. And if we cannot agree on this design, then this needs further discussion.

@a-r-r-o-w (Owner, Author)

a-r-r-o-w commented Dec 23, 2024

But it's okay not to proceed in the direction I mentioned. However, I would not prefer having a single Trainer class, if that is what you absolutely want.

What is okay with me is:

  • A separate class per algorithm. So one for LoRA, one for full finetuning, one for task X, one for task Y, etc.

What is not okay with me:

  • One for LTX LoRA, one for LTX full finetuning, one for Mochi control LoRA, one for Mochi normal LoRA, etc. This will be a mess.
  • Clubbing all of LoRA, full finetuning, control, and distillation under each "model" class, with one trainer per model.

Preparing conditions and latents is a simple enough thing, IME, that it won't require an implementation for each kind of algorithm and can generally be re-used.

I would also not be okay with each trainer class having to implement its own N different "algorithm" implementations. Basically, I don't want LTXTrainer to have its own implementation for all the different tasks mentioned unless absolutely necessary. Doing so will only create bugs and make things harder to maintain. If it were instead called something like LTXTrainingPipeline and implemented all the necessary helper methods, that would be better for me - these should have absolutely no training-related things intermixed and should be completely decoupled. The one thing that would be okay not to decouple from training is loss computation, because we know there are two broad categories of models that exist today from a diffusion perspective. But again, it's maybe too early to couple in the loss, because it is easy to deduce what loss a model requires just by looking at the Scheduler being used.

What I would find preferable is LoRATrainer, FullTrainer, GuidanceDistillationTrainer, etc. All the other underlying model-specific parts can be re-used across these, IMO.
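
A bare-bones sketch of that layering, just to pin down the vocabulary (all class and method names here are hypothetical):

class LTXTrainingPipeline:
    # Model-specific helpers only: no optimizers, no loops, no loss.
    def load_components(self): ...
    def prepare_conditions(self, prompts): ...
    def prepare_latents(self, videos): ...
    def forward(self, latents, conditions, timesteps): ...


class LoRATrainer:
    # Algorithm-specific: injects LoRA adapters and runs the loop, for any model.
    def __init__(self, pipeline):
        self.pipeline = pipeline  # LTXTrainingPipeline, HunyuanTrainingPipeline, ...

    def precompute(self, prompts, videos):
        return self.pipeline.prepare_conditions(prompts), self.pipeline.prepare_latents(videos)

    def train(self, conditions, latents):
        ...  # LoRA injection, noise/timestep sampling, loss, optimizer step


class FullFinetuneTrainer(LoRATrainer):
    # Same loop; only the set of trainable parameters differs.
    pass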

@sayakpaul (Collaborator)

sayakpaul commented Dec 23, 2024

So, this means you're not okay with the direction taken in #130, right?

What I would find preferable is LoRATrainer, FullTrainer, GuidanceDistillationTrainer, etc.

There are very minimal differences between LoRATrainer and FullTrainer, though. So, I don't see the point.

@a-r-r-o-w (Owner, Author)

a-r-r-o-w commented Dec 23, 2024

There are very minimal differences between LoRATrainer and FullTrainer, though. So, I don't see the point.

Okay, I feel like despite my efforts, it is extremely hard to get my point across. Please make an effort not to pick two of the simplest examples and extrapolate them to the entirety of the message. Do we not agree that LoRA is a different training technique from full finetuning? I am trying to discuss training algorithms, not how similar or dissimilar they are.

My point is, we should have trainers specific to algorithms, not models. A trainer should be model-independent as long as the model provides all the utility functions required for training. Let's not tie the two together - this inevitably leads to bugs, and I've seen my fair share with the diffusers scripts that do things model-specifically. That approach is hard to use and unstable, but it serves its purpose in that it allows for hacking and exploration.

Here are the directions I'm not okay with in #130:

  • Dataset loading tied into every trainer. I don't understand why this would be required. Every model takes in the same format of data for a given task. Happy to be shown an example that proves otherwise.
  • Trainable parameter preparation tied into every trainer. This should be done one level higher up the abstraction chain, depending on the task being trained and irrespective of the model used, unless absolutely required.
  • Loss calculation. There are only two kinds of losses with diffusion models in regards to LoRA training, which is all we support for now: DDPM and flow matching (see the sketch after this list). Anything else should be an augmentation loss, which could be used with any other model. These should be enabled by training flags. However, we are not at the stage where we could consider things like this (for example, motion regularization or latent perceptual loss), so it is okay for now. Unless a paper specifically introduces a loss for a model, and that loss is not compatible with other models, we should use the normal base losses.
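
Here is the sketch referenced above: a hedged illustration of how the base loss could be selected from the scheduler alone. The dispatch-by-class-name and the omission of any sigma/SNR weighting are simplifications of mine, not how we currently implement it.

import torch.nn.functional as F


def compute_base_loss(model_pred, latents, noise, scheduler):
    # Illustrative: pick the objective from the scheduler type, nothing model-specific.
    if "FlowMatch" in scheduler.__class__.__name__:
        target = noise - latents  # flow matching: the model predicts the velocity noise - x0
    else:
        target = noise            # DDPM-style epsilon prediction: the model predicts the added noise
    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")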

To reiterate, the models should not tie in with the actual training algorithm in any way. Of all the trainers I've looked at, those that do this have very complicated code (partially also because they build atop the LDM codebase). SimpleTuner is a good example of not tying the loss to the model itself, and to me it is the most successful training project built on Diffusers because it "just works" OOTB and you don't have to think much to get it working. I would take inspiration from there, follow some of its design patterns, and improve wherever possible while customizing to our broader set of goals.

@sayakpaul (Collaborator)

sayakpaul commented Dec 23, 2024

It's okay to have disagreements, and I can see some of your points. So, I will roll with this direction. Making changes later isn't a big deal.

The reason why I wanted to separate dataset prep is that not all models have the same components we need during precomputation (different text encoders, different masking, etc.). So, unifying that under a single method becomes very messy.

And then for caption dropout, different models do it differently, as mentioned the other day. It's easy to miss this detail, but it sometimes plays a crucial role.

For loss, Mochi does it differently than the other flow models we have.

TL;DR -- let's get this PR merged. I will work on Cog next as explained over DM.

@a-r-r-o-w (Owner, Author)

The reason why I wanted to separate dataset prep is that not all models have the same components we need during precomputation (different text encoders, different masking, etc.). So, unifying that under a single method becomes very messy.

This, I agree with. What I want is for the utility functions to be part of the model-specific implementations, not the loading and dataloader creation itself. That should be a separate, decoupled component. We leverage the utility functions conditionally in the trainer (just throwing out ideas):

text_conditions = self.model_trainer_pipeline.prepare_conditions(...)
# prepare_conditions will get one or more prompts based on some configured batch
# size, and prepare everything it requires, using multiple text encoders if needed.
# For example, I already do this without much complication for Hunyuan, which uses
# 2 text encoders.

And then for caption dropout, different models do it differently, as mentioned the other day. It's easy to miss this detail, but it sometimes plays a crucial role.

Again, this is a trainer-specific detail and should not be tied into the model utilities themselves. The trainer will invoke the preparation methods. If it's an "empty"-ing dropout, a null prompt will be passed to the preparation method. If it's a "tensor zero"-ing dropout, it can be done post-preparation. We already do this. It should be configurable via CLI args to enable exploration of what works, not hardcoded for certain models. We can provide sensible per-model guidelines on what to do.
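
A small sketch of what that configurable dropout could look like (the flag values "empty"/"zero" and the helper name are illustrative, not existing finetrainers code):

import random

import torch


def apply_caption_dropout(prompt, embedding, technique, probability):
    # Illustrative helper the trainer would call; nothing here is model-specific.
    if random.random() >= probability:
        return prompt, embedding
    if technique == "empty":
        # Null prompt: re-encode "" (or look up a precomputed null embedding) downstream.
        return "", embedding
    if technique == "zero":
        # Tensor zero-ing: applied after preparation, on the precomputed embedding.
        return prompt, torch.zeros_like(embedding)
    raise ValueError(f"Unknown caption dropout technique: {technique}")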

For loss, Mochi does it differently than the other flow models we have.

It's one case among the thousands of flow-matching models. It's not a different flow-matching algorithm as such. It's just a mistake that occurred in the original codebase, as I like to believe. Instead of the usual convention (timestep 1000 = noise, timestep 0 = clean), they have a model trained with an inverted objective.
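
As a quick illustration of what the inverted convention amounts to (my own simplified sketch of the arithmetic, not Mochi's actual code):

import torch

num_train_timesteps = 1000
t = torch.tensor([0.0, 250.0, 500.0, 750.0, 1000.0])

# Usual flow-matching convention: sigma = t / 1000, so t=1000 is pure noise and t=0 is clean.
sigma_usual = t / num_train_timesteps

# Inverted convention: same linear interpolation, timestep axis flipped.
sigma_inverted = 1.0 - sigma_usual

# Either way, the noisy sample is x_t = (1 - sigma) * x0 + sigma * noise;
# only the meaning of the timestep index changes, not the algorithm.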

TL;DR -- let's get this PR merged. I will work on Cog next as explained over DM.

Thank you. I realize I might have spoken a bit in spirit, and strongly, about the way I expect things to move forward. I think it is good to have these discussions in the early stages rather than realizing issues later. Working my way through the reviews and will push shortly.

@sayakpaul (Collaborator)

Again, this is a trainer-specific detail and should not be tied into the model utilities themselves. The trainer will invoke the preparation methods. If it's an "empty"-ing dropout, a null prompt will be passed to the preparation method. If it's a "tensor zero"-ing dropout, it can be done post-preparation. We already do this. It should be configurable via CLI args to enable exploration of what works, not hardcoded for certain models. We can provide sensible per-model guidelines on what to do.

Just so you know, it can be model-specific :D If a model was pretrained with a particular caption-dropping method, that will have to be accounted for. So, we can't eliminate the possibility, and it's something a user should be mindful of. Hence, by default, we shouldn't do any caption dropout. I guess that is fair?

@a-r-r-o-w (Owner, Author)

Just so you know, it can be model-specific :D If a model was pretrained with a particular caption-dropping method, that will have to be accounted for. So, we can't eliminate the possibility, and it's something a user should be mindful of. Hence, by default, we shouldn't do any caption dropout. I guess that is fair?

It's very rarely the case that an empty/null prompt is not used. Maybe some of the recent models do this, and for those cases, when the trainer object is initialized, I would rather a warning be raised than disable the possibility of doing anything else altogether.

# Somewhere in the model utility code
if model_name == "mochi" and args.caption_dropout_technique == "empty":
    logger.warning("You did bad thing. Maybe don't do it.")

Just like I said above, let's get to it when we get to it instead of thinking about all of this now. Pointless debates only slow us down. None of the models currently require us to diverge from the sensible defaults.

@sayakpaul (Collaborator)

Okay. I was trying to be thorough about issues that can silently creep into training, but it seems like we don't have to worry about them that much for now. Let's not call that a pointless debate.

@a-r-r-o-w a-r-r-o-w requested a review from sayakpaul December 23, 2024 13:17
@sayakpaul (Collaborator) left a review comment


Just some minor nits. Otherwise, looks good.

(Review threads on README.md, finetrainers/args.py, finetrainers/dataset.py, finetrainers/ltx_video/ltx_video_lora.py, and finetrainers/trainer.py - all resolved.)
@a-r-r-o-w (Owner, Author)

a-r-r-o-w commented Dec 23, 2024

I believe I've addressed all reviews. Proceeding with merge. Follow-ups:

@sayakpaul (Collaborator)

Works for me, thanks for chalking that out. I can also support training with bnb-quantized checkpoints, similar to https://github.com/huggingface/diffusers/tree/main/examples/research_projects/flux_lora_quantization (something that has worked wonders in the LLM world and also for Flux) ;)
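
For reference, a rough sketch of what that could look like here, mirroring the linked Flux example; the LTX transformer class, repo id, rank, and target modules below are placeholders, not something this repo does yet:

import torch
from diffusers import BitsAndBytesConfig, LTXVideoTransformer3DModel
from peft import LoraConfig

# Load the frozen base transformer in 4-bit, then attach trainable LoRA layers on top.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video",  # placeholder model id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # placeholder target modules
)
transformer.add_adapter(lora_config)  # only the LoRA parameters remain trainable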

@sayakpaul sayakpaul merged commit cf9be17 into main Dec 23, 2024
@sayakpaul sayakpaul deleted the condition-precomputation branch December 23, 2024 14:30
@sayakpaul (Collaborator)

sayakpaul commented Dec 23, 2024

Will add:

  • DeepSpeed support
  • Separating trainer classes w.r.t. training algorithms
