Introduce OptimizerArgs and add support for GaLore #1192
Conversation
After our discussion today, I think we should only enable vanilla GaLore for now and not worry about the LoRA support. We can look into LoRA support later if there is high demand. I am getting some precision-related errors when trying to use it with LoRA, which likely has something to do with the precision used by the GaLore optimizer under the hood. I expect the GaLore package to evolve in the upcoming weeks and months, and we can then revisit whether LoRA works without us having to make additional tweaks to the GaLore optimizer.
@rasbt How much of an improvement in VRAM consumption did you see with LoRA + GaLore?
The combination of LoRA + GaLore doesn't really work yet due to precision mismatches when merging the LoRA weights at the end, so it didn't get to the code line that prints the memory usage. I could comment out the merging and try it again, but I think let's just focus on GaLore for full finetuning first. Like you said, I don't expect a big improvement when combined with LoRA.
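For context, here is a generic, standalone sketch of the kind of dtype mismatch that shows up at merge time (this is not litgpt's merging code, just an illustration):

```python
import torch

# A float16 pretrained weight and a float32 LoRA delta: the merge error listed
# in this PR's todos is raised when these dtypes differ at merge time.
pretrained_weight = torch.randn(4, 4, dtype=torch.float16)
lora_delta = torch.randn(4, 4, dtype=torch.float32)

# A generic workaround is an explicit cast before merging; whether that is
# acceptable here depends on the precision the GaLore optimizer expects.
merged = pretrained_weight + lora_delta.to(pretrained_weight.dtype)
print(merged.dtype)  # torch.float16
```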
I changed the …

| Method | Optimizer |
| --- | --- |
| Full | AdamW |
| Full | GaLore |
| Full | GaLore 8-bit |
| LoRA | AdamW |
| LoRA | GaLore |
| LoRA | GaLore 8-bit |
| Adapter | AdamW |
| Adapter | GaLore |
| Adapter | GaLore 8-bit |
| Adapter v2 | AdamW |
| Adapter v2 | GaLore |
| Adapter v2 | GaLore 8-bit |
| Pretrain (Pythia 14M) | AdamW |
| Pretrain (Pythia 14M) | GaLore |
| Pretrain (Pythia 14M) | GaLore 8-bit |

Full + AdamW:

```bash
litgpt finetune full \
  --checkpoint_dir checkpoints/microsoft/phi-2/ \
  --train.max_steps 5
# Training time: 32.76s
# Memory used: 55.84 GB
```
I tried many things and even ended up replacing all instances of torch's AdamW with GaLore's to make sure it's actually used, but for some reason, I cannot see any difference in memory usage when pretraining. Mind-boggling.
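One thing that might explain the lack of savings: according to the galore-torch README, the low-rank projection is only applied to parameters placed in a param group that sets `rank`; parameters passed without those group settings get a plain AdamW update. A minimal sketch of the README-style usage follows (the toy model and the `rank`/`scale` values are purely illustrative):

```python
import torch
from galore_torch import GaLoreAdamW

# Toy model; in practice these would be the transformer's 2-D weight matrices.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

# GaLore projects only the parameters in the group that defines `rank`;
# everything else is updated with a regular AdamW step.
galore_params = [p for p in model.parameters() if p.ndim == 2]
galore_ids = {id(p) for p in galore_params}
regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

param_groups = [
    {"params": regular_params},
    {"params": galore_params, "rank": 4, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW(param_groups, lr=1e-2)
```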
I changed the hardcoded GaLore arguments to general ones. So, what's new is that we now have optimizer kwargs. E.g., this adds …
What do you think about this approach and interface, @carmocca @lantiga @awaelchli?
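For illustration only, a rough sketch of what an optimizer-kwargs style `OptimizerArgs` could look like; the field names and defaults below are assumptions for discussion, not the code in this PR:

```python
from dataclasses import dataclass

# Hypothetical sketch of an OptimizerArgs-style container; the fields are
# illustrative assumptions, not the interface implemented in this PR.
@dataclass
class OptimizerArgs:
    name: str = "AdamW"          # optimizer to construct, e.g. "AdamW" or a GaLore variant
    learning_rate: float = 3e-4
    weight_decay: float = 0.02
    galore_8bit: bool = False    # whether to use the 8-bit GaLore variant
    galore_rank: int = 128       # rank of the GaLore gradient projection
```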
The jsonargparse-y way of doing this would be to instead specify which Optimizer class you want to select, and let the parser pull out the arguments of said class. For example, that is exactly how the data is selected and parsed.
OMG, I made it way more complicated than it needs to be 🤦♂️. Thanks for the hint. Now I know.
After trying this, I realize that this may not be cleanly possible, because optimizers require the model parameters at instantiation time, which we don't have yet when parsing. We could probably have this jsonargparse approach for PyTorch native optimizers, but I don't think it will be easy to support GaLore this way in a non-hacky way. I can make a PR with just PyTorch optimizer support and then we can decide which route we want to go: only supporting PyTorch optimizers, or revisiting this implementation here with our own OptimizerArgs.
Yes, we cannot have jsonargparse instantiate the class directly for that reason. But you can still tell it to add all the arguments of a class (or classes) into a group of args, basically getting you OptimizerArgs automatically for that class. Then those args can be used to instantiate the real optimizer instance later in the script. The PyTorch Lightning CLI implementation works that way: https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/cli.py#L154-L177
Argh, I am still struggling with this. I.e., … works without problems, but then even if I don't do anything else, jsonargparse already tries to initialize it via … before I can pass it to anything else. Not sure how to avoid that.
You can start by understanding this minimal example:

```python
import torch
import jsonargparse

parser = jsonargparse.ArgumentParser()
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()
print(args)
```

Running

```
python example.py --optimizer Adam
```

prints:

```
Namespace(optimizer=Namespace(class_path='torch.optim.Adam', init_args=Namespace(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None)))
```
And here's how you would use the above to instantiate the optimizer:

```python
from typing import Any, Tuple, Dict, Union


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


model = torch.nn.Linear(1, 1)
optimizer = instantiate_class(model.parameters(), init=args["optimizer"])
print(optimizer)
```

We define …
The current implementation adds GaLore to the full finetuning script.

Example

…

Discuss

- We could also add it to LoRA (… `GaloreAdamW8Bit`).
- I specified the galore args similar to what we do with lora. But since this is more an addon to existing methods like `full` and `lora`, should we maybe make this part of `TrainArgs`?
- We can also think about making a dedicated subcommand like for qlora in the future. I.e., …

Todos

- `NotImplementedError: Cannot merge the pretrained weights of type torch.float16 and LoRA weights of type torch.float32`
- `galore_8bit = True`?

Fixes #1075