Introduce OptimizerArgs and add support for GaLore #1192

Closed
wants to merge 28 commits

Conversation

@rasbt (Collaborator) commented Mar 25, 2024

The current implementation adds GaLore to the full finetuning script.

Example

# regular
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5 

# Training time: 14.13s
# Memory used: 3.44 GB



# with galore
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5  \
  --galore.use_galore true

# Training time: 23.59s
# Memory used: 3.44 GB



# with 8bit galore
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --data Alpaca2k \
  --train.max_steps 5  \
  --galore.use_galore true \
  --galore.galore_8bit

# Training time: 17.96s
# Memory used: 2.47 GB

Discuss

We could also add it to LoRA

  • this would require a check that GaLore is only used when QLoRA is disabled
  • we can actually use it with some bnb precision settings (this would be supported according to them via GaloreAdamW8Bit)

I specified the GaLore args similarly to what we do with LoRA. But since this is more of an add-on to existing methods like full and lora, should we maybe make it part of TrainArgs?
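For illustration, a minimal sketch of what such a dedicated args dataclass could look like, styled after litgpt's other *Args dataclasses; the use_galore and galore_8bit fields mirror the CLI flags shown above, while the remaining fields and defaults are assumptions based on common GaLore hyperparameters, not this PR's actual code.

from dataclasses import dataclass

@dataclass
class GaloreArgs:
    """Hypothetical GaLore-related arguments (sketch only, not the PR's definition)."""
    use_galore: bool = False       # enable the GaLore optimizer
    galore_8bit: bool = False      # use the 8-bit GaLore variant
    rank: int = 128                # rank of the low-rank gradient projection (assumed default)
    update_proj_gap: int = 200     # steps between projection updates (assumed default)
    scale: float = 0.25            # scaling factor for projected gradients (assumed default)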

We can also think about making dedicated subcommands in the future, like for qlora. I.e.,

litgpt finetune full --config ... 

litgpt finetune lora --config ... 

litgpt finetune qlora --config ... [in progress]

litgpt finetune galore --config ... [maybe in future]

Todos

  • Add Galore for full finetuning
  • Check if default args are good
  • Add docstrings
  • Discuss if we use TrainArgs (see above)
  • Add GaLore for lora finetuning (investigate NotImplementedError: Cannot merge the pretrained weights of type torch.float16 and LoRA weights of type torch.float32)
  • Throw an error if GaLore and QLoRA are used at the same time and QLoRA is not 8-bit
  • Should we also allow 8-bit GaLore without QLoRA? I'd say yes. How? Via galore_8bit = True?
  • Update full and lora config files
  • Add galore for pretraining
  • Consider adding it for adapter and adapter v2
  • Add tests
  • Restrict to single GPU training
  • Add GaLore package to the acknowledgements section
  • Add documentation
  • Add config YAMLs and benchmarks

Fixes #1075

@rasbt rasbt mentioned this pull request Mar 25, 2024
@rasbt rasbt marked this pull request as draft March 25, 2024 22:00
@rasbt (Collaborator, Author) commented Mar 26, 2024

After our discussion today, I think we should only enable vanilla GaLore for now and not worry about LoRA support. We can look into LoRA support later if there is high demand. I am getting some precision-related errors when trying to use it with LoRA, which likely has something to do with the precision used by the GaLore optimizer under the hood. I expect the GaLore package to evolve in the upcoming weeks and months, and we can then revisit whether LoRA works without us having to make additional tweaks to the GaLore optimizer etc.

@Andrei-Aksionov (Collaborator) commented

@rasbt How much of an improvement in VRAM consumption did you see with LoRA + GaLore?
With any PEFT algorithm, the number of parameters to optimize shouldn't be that significant.
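For a rough sense of scale (the numbers below are assumptions for illustration, not measurements from this PR): with only a few million trainable LoRA parameters, the AdamW state that GaLore could compress is small to begin with.

# Back-of-envelope sketch; the trainable-parameter count is an assumed, typical
# order of magnitude for LoRA on a ~3B model, not a measurement from this PR.
trainable_params = 5_000_000
adamw_state_bytes = trainable_params * 2 * 4  # exp_avg + exp_avg_sq in fp32
print(f"AdamW optimizer state: ~{adamw_state_bytes / 1e9:.2f} GB")  # ~0.04 GB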

@rasbt (Collaborator, Author) commented Mar 27, 2024

The combination of LoRA + GaLore doesn't really work yet due to precision mismatches when merging the LoRA weights at the end, so it didn't get to the code line that prints the memory usage. I could comment out the merging and try again, but I think let's just focus on GaLore for full finetuning first. Like you said, I don't expect a big improvement when combined with LoRA.

@rasbt (Collaborator, Author) commented May 3, 2024

I changed the GaloreArgs to OptimizerArgs, and here are some results for phi-2. What's puzzling is the pretraining performance. I couldn't find the issue and may need to investigate more. I also need to update the config files once we've settled on the API.

Full

AdamW

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 32.76s
# Memory used: 55.84 GB

GaLore

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 128.55s
# Memory used: 36.14 GB

GaLore 8-bit

litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 128.68s
# Memory used: 33.81 GB

LoRA

AdamW

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 36.43s
# Memory used: 18.56 GB

GaLore

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 25.98s
# Memory used: 18.56 GB

GaLore 8-bit

litgpt finetune lora \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 26.01s
# Memory used: 18.54 GB

Adapter

AdamW

litgpt finetune adapter \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 31.16s
# Memory used: 17.94 GB

GaLore

litgpt finetune adapter \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"

# Training time: 24.81s
# Memory used: 17.94 GB

GaLore 8-bit

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"

# Training time: 26.36s
# Memory used: 20.10 GB

Adapter v2

AdamW

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5

# Training time: 26.35s
# Memory used: 20.11 GB

GaLore

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw"
# Training time: 26.31s
# Memory used: 20.11 GB

GaLore 8-bit

litgpt finetune adapter_v2 \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5 \
  --optim.optimizer "galore_adamw_8bit"
# Training time: 26.26s
# Memory used: 20.10 GB

Pretrain (Pythia 14M)

AdamW

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000

# Training time: 34.07s
# Memory used: 1.44 GB

GaLore

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000 \
  --optim.optimizer "galore_adamw"

# Training time: 25.31s
# Memory used: 1.44 GB

GaLore 8-bit

litgpt pretrain \
  --model_name pythia-14m \
  --tokenizer_dir checkpoints/EleutherAI/pythia-14m/ \
  --data TextFiles \
  --data.train_data_path "custom_texts" \
  --train.max_tokens 100_000 \
  --optim.optimizer "galore_adamw_8bit"
# Training time: 25.31s
# Memory used: 1.44 GB

@rasbt (Collaborator, Author) commented May 6, 2024

I tried many things and even ended up replacing all instances of torch's AdamW with GaLore's to make sure it's actually used, but for some reason, I cannot see any difference in memory usage when pretraining. Mind-boggling.
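A minimal sketch of the kind of sanity check described above (not code from this PR), assuming a CUDA run: print the optimizer class actually in use and the peak allocated memory around the training steps.

import torch

def report_optimizer_and_memory(optimizer: torch.optim.Optimizer) -> None:
    # Confirm which optimizer implementation is active and how much CUDA memory peaked.
    print("optimizer class:", type(optimizer).__module__ + "." + type(optimizer).__qualname__)
    if torch.cuda.is_available():
        print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# Usage sketch: call torch.cuda.reset_peak_memory_stats() before the training steps,
# then report_optimizer_and_memory(optimizer) afterwards.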

@rasbt rasbt marked this pull request as ready for review May 9, 2024 18:27
@rasbt (Collaborator, Author) commented May 9, 2024

I changed the hardcoded GaLore arguments to general extra_kwargs so they can be used for other optimizer options as well. This way it adds less clutter to the CLI.

So, what's new is that we now have optimizer kwargs. E.g., this adds:

# Optimizer-related arguments
optim: 
  # Which optimizer to use. Possible choices: "adamw", "galore_adamw", "galore_adamw_8bit". (type: Optional[str], default: "adamw")
  optimizer: "adamw"

  #   (type: float, default: 0.0003)
  learning_rate: 0.0002

  #   (type: float, default: 0.02)
  weight_decay: 0.0

  #   (type: float, default: 0.9)
  beta1: 0.9

  #   (type: float, default: 0.95)
  beta2: 0.95

  # Additional optimizer keyword arguments, for example, "rank=8,update_proj_gap=200" for GaLore. (type: Optional[str], default: None)
  extra_kwargs:
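For reference, a minimal sketch (not the PR's actual implementation) of how such an extra_kwargs string could be turned into keyword arguments for the chosen optimizer; the "rank=8,update_proj_gap=200" format is the one from the docstring above, everything else is an assumption.

import ast
from typing import Any, Dict, Optional

def parse_extra_kwargs(extra_kwargs: Optional[str]) -> Dict[str, Any]:
    """Turn 'rank=8,update_proj_gap=200' into {'rank': 8, 'update_proj_gap': 200}."""
    if not extra_kwargs:
        return {}
    kwargs: Dict[str, Any] = {}
    for pair in extra_kwargs.split(","):
        key, value = pair.split("=", 1)
        try:
            kwargs[key.strip()] = ast.literal_eval(value.strip())  # int/float/bool/str literals
        except (ValueError, SyntaxError):
            kwargs[key.strip()] = value.strip()  # fall back to the raw string
    return kwargs

# e.g. the parsed kwargs would then be merged with learning_rate, weight_decay,
# and the betas when constructing the selected optimizer.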

What do you think about this approach and interface @carmocca @lantiga @awaelchli ?

@rasbt rasbt changed the title Add support for GaLore Introduce OptimizerArgs and add support for GaLore May 9, 2024
@rasbt rasbt requested a review from williamFalcon as a code owner May 10, 2024 00:42
@carmocca (Contributor) commented

The jsonargparse-y way of doing this would be to instead specify which Optimizer class you want to select and let the parser pull out the arguments of said class. For example, that is exactly how the data is selected and parsed.

@rasbt (Collaborator, Author) commented May 10, 2024

OMG, I made it way more complicated than it needed to be 🤦‍♂️. Thanks for the hint. Now I know.

@rasbt (Collaborator, Author) commented May 10, 2024

After trying this, I realize that it may not be cleanly possible because optimizers require params as a positional argument. So we would have to wrap the optimizer in our own optimizer class. The other problem is with the GaLore optimizer, which needs to split the params into regular params and GaLore params before passing them. It kind of gets ugly real quick.

We could probably use this jsonargparse approach for PyTorch-native optimizers, but I don't think it will be easy to support GaLore this way in a non-hacky manner.

I can make a PR with just PyTorch optimizer support, and then we can decide which route we want to go: only supporting PyTorch optimizers, or revisiting this implementation here with our own extra_args parsing.
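For illustration, a sketch of the parameter splitting mentioned above, following the usage shown in the galore_torch README at the time; the 2D-linear-weight heuristic and the hyperparameter values here are assumptions, not this PR's code.

import torch
from galore_torch import GaLoreAdamW  # pip install galore-torch

def build_galore_optimizer(model: torch.nn.Module, lr: float = 1e-4) -> GaLoreAdamW:
    # GaLore projects the gradients of (typically 2D) linear-layer weights;
    # everything else stays in a regular param group.
    galore_params = [m.weight for m in model.modules()
                     if isinstance(m, torch.nn.Linear) and m.weight.requires_grad]
    galore_ids = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters()
                      if p.requires_grad and id(p) not in galore_ids]
    param_groups = [
        {"params": regular_params},
        {"params": galore_params, "rank": 128, "update_proj_gap": 200,
         "scale": 0.25, "proj_type": "std"},  # group keys per the galore_torch README
    ]
    return GaLoreAdamW(param_groups, lr=lr)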

@carmocca (Contributor) commented

Yes, we cannot have jsonargparse instantiate the class directly for that reason.

But you can still tell it to add all the arguments of a class (or classes) into a group of args, basically getting you OptimizerArgs automatically for that class. Then those args can be used to instantiate the real optimizer instance later in the script.

The PyTorch Lightning CLI implementation works that way: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/cli.py#L154-L177

@rasbt (Collaborator, Author) commented May 10, 2024

Arg, I am still struggling with this.

I.e.,

 litgpt finetune full --optimizer.help torch.optim.AdamW    

works without a problem, but then, even if I don't do anything else, jsonargparse already tries to instantiate it via

litgpt finetune full  ... --optimizer torch.optim.AdamW 

before I can pass it to anything else. Not sure how to avoid that.
I think I need to study jsonargparse a bit better because right now I feel like I am trying to hack things together somehow ...

@carmocca (Contributor) commented

You can start by understanding this minimal example:

import torch
import jsonargparse

parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()
print(args)

python example.py --optimizer Adam
Namespace(optimizer=Namespace(class_path='torch.optim.Adam', init_args=Namespace(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None)))

@carmocca (Contributor) commented

And here's how you would use the above to instantiate the optimizer:

import torch
from typing import Any, Tuple, Dict, Union

def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.

    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


model = torch.nn.Linear(1, 1)
optimizer = instantiate_class(model.parameters(), init=args["optimizer"])
print(optimizer)

We define instantiate_class for the PyTorch Lightning CLI here: https://github.com/Lightning-AI/pytorch-lightning/blob/90d04b5b86f37994cdceccc6de32f0e93b1cc7f0/src/lightning/pytorch/cli.py#L752-L769

@rasbt rasbt mentioned this pull request May 10, 2024
@rasbt rasbt closed this Jul 3, 2024
@rasbt rasbt deleted the galore branch September 24, 2024 17:16