Factor out the Trainer class
#130
Conversation
Thanks! Just some concerns and ideas for the first draft.
I believe we're going to have some merge conflicts with #129 and would prefer that be merged first because of the significant memory savings from precomputation. Happy to help resolve any merge conflicts and proceed further here.
import os
import sys

base_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
I don't prefer this. train.py should be the entry point and determine which trainer to use based on the provided arguments, so I would prefer not having any changes to the import style.
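A minimal sketch of what that dispatch could look like (the `--model` values, import paths, and `HunyuanVideoTrainer` name below are assumptions for illustration, not the actual finetrainers CLI or layout):

```python
# train.py -- hypothetical sketch of argument-based trainer selection.
import argparse


def get_trainer_cls(model: str):
    # Map a user-provided model name to the matching Trainer subclass.
    if model == "ltx_video":
        from finetrainers.ltx_video import LTXTrainer  # assumed module path

        return LTXTrainer
    if model == "hunyuan_video":
        from finetrainers.hunyuan_video import HunyuanVideoTrainer  # assumed module path

        return HunyuanVideoTrainer
    raise ValueError(f"Unsupported model: {model}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    args, _ = parser.parse_known_args()

    trainer = get_trainer_cls(args.model)()
    trainer.train()  # assumed public entry point on the Trainer class
```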
class LTXTrainer(Trainer):
    def prepare_models(self):
from tqdm import tqdm

from .args import Args, validate_args
Let's not import it this way by changing base paths through the os and sys modules. I prefer the same way it's done in diffusers, with relative imports, but if there's a benefit to doing it this way, could you help me understand?
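For illustration, a diffusers-style layout with relative imports (the module names below are assumptions) avoids the path manipulation entirely:

```python
# finetrainers/ltx_video/trainer.py -- hypothetical module layout for illustration.
# With relative imports, no sys.path/os.path surgery is needed as long as the code
# runs as part of the package (e.g. `python -m finetrainers.train` or after `pip install -e .`).
from ..args import Args, validate_args  # assumed sibling module
from ..trainer import Trainer           # assumed shared base class


class LTXTrainer(Trainer):
    def prepare_models(self):
        ...
```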
def calculate_loss_weights(self, **kwargs):
    raise NotImplementedError

def sort_out_checkpoint_to_resume_from(self, accelerator):
Would maybe make this a helper in checkpoint_utils.py and pass all the relevant arguments.
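Something along these lines, for example (a sketch only; the argument names and the "checkpoint-&lt;step&gt;" directory convention are assumed from the surrounding diff):

```python
# checkpoint_utils.py -- hypothetical helper for resolving the checkpoint to resume from.
import os
from typing import Optional


def get_latest_checkpoint_path(output_dir: str, resume_from_checkpoint: str) -> Optional[str]:
    """Return the checkpoint directory to resume from, or None if none exists."""
    if resume_from_checkpoint != "latest":
        return resume_from_checkpoint
    # Pick the checkpoint with the highest step, assuming "checkpoint-<step>" naming.
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    return os.path.join(output_dir, dirs[-1]) if len(dirs) > 0 else None
```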
path = dirs[-1] if len(dirs) > 0 else None

if path is None:
    accelerator.print(
Let's not use accelerator.print. logger.info should work, and even if this is moved to a different file, get_logger("finetrainers") will allow us to perform the logging.
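For example, with accelerate's logging helper (a minimal sketch; the message text is illustrative):

```python
from accelerate.logging import get_logger

# Any module in the project can retrieve the same logger by name.
logger = get_logger("finetrainers")


def log_resume_status(path) -> None:
    # Illustrative messages; logger.info only emits on the main process by default.
    if path is None:
        logger.info("Checkpoint does not exist. Starting a new training run.")
    else:
        logger.info(f"Resuming from checkpoint {path}")
```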
It was here before. I did not use print.
    return initial_global_step, global_step, first_epoch

def save_intermediate_checkpoint(self, step, accelerator):
Would also move this into a utils file, making sure to pass all required arguments, and import it here.
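For instance (a sketch with assumed names; the real signature would need whatever state the method currently reads from self):

```python
# checkpoint_utils.py -- hypothetical free-function form of save_intermediate_checkpoint.
import os

from accelerate.logging import get_logger

logger = get_logger("finetrainers")


def save_intermediate_checkpoint(accelerator, output_dir: str, step: int) -> None:
    """Save the accelerator state under output_dir/checkpoint-<step>."""
    save_path = os.path.join(output_dir, f"checkpoint-{step}")
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")
```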
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")

def save_final_checkpoint(self, accelerator):
LoRA and full finetuning have different ways of handling this, so would make utility functions here too and import them. For now, since we only have LoRA, this is okay to keep here too.
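For the LoRA case, such a utility could look roughly like this (a sketch following the usual diffusers LoRA saving pattern; the function and argument names are assumptions, and a full-finetuning counterpart would save the whole transformer state dict instead):

```python
# Hypothetical utility for the LoRA branch of the final save.
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict


def save_final_lora_checkpoint(accelerator, transformer, pipeline_cls, output_dir: str) -> None:
    """Extract LoRA weights from the trained transformer and write them to output_dir."""
    transformer = accelerator.unwrap_model(transformer)
    lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))
    pipeline_cls.save_lora_weights(output_dir, transformer_lora_layers=lora_state_dict)
```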
@a-r-r-o-w this is how I am envisioning the Trainer class.

To test:

Of course, this will break HunyuanVideo, so I am happy to cut a dev branch and merge the current PR into that, then fix HunyuanVideo on dev. After that, we should be safe to merge dev into main.

Training logs: https://wandb.ai/sayakpaul/finetrainers-ltxv/runs/qcqbqjxq
WDYT?