diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py
index aa02977fb9..2ec71784e7 100644
--- a/litgpt/finetune/adapter.py
+++ b/litgpt/finetune/adapter.py
@@ -145,7 +145,6 @@ def main(
 
     model = fabric.setup_module(model)
 
-    trainable_params = [p for p in model.parameters() if p.requires_grad]
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         import bitsandbytes as bnb
 
@@ -153,7 +152,7 @@ def main(
     else:
         optimizer_cls = torch.optim.AdamW
     optimizer = optimizer_cls(
-        trainable_params, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
+        model.parameters(), lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
     )
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index 197bcd0bba..86526a58e5 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -145,7 +145,6 @@ def main(
 
     model = fabric.setup_module(model)
 
-    trainable_params = [p for p in model.parameters() if p.requires_grad]
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         import bitsandbytes as bnb
 
@@ -153,7 +152,7 @@ def main(
     else:
         optimizer_cls = torch.optim.AdamW
     optimizer = optimizer_cls(
-        trainable_params, lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
+        model.parameters(), lr=train.learning_rate, weight_decay=train.weight_decay, betas=(train.beta1, train.beta2)
     )
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
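
For context on why dropping the filtered `trainable_params` list should not change which weights get updated: `torch.optim.AdamW.step()` skips any parameter whose `.grad` is `None`, so frozen (non-adapter) weights receive neither an update nor weight decay even when they are passed to the optimizer. The snippet below is a minimal standalone sketch of that behavior, not litgpt code; the toy model, its layer shapes, and the hyperparameters are made up for illustration.

```python
import torch
import torch.nn as nn

# Toy stand-in for the adapter-wrapped model: one frozen layer ("base" weights)
# and one trainable layer ("adapter" weights).
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False

# Pass *all* parameters, as the patched code does.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.02)

frozen_before = model[0].weight.detach().clone()

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# The frozen layer never receives a .grad, so AdamW's step() skips it entirely
# (no update, no decoupled weight decay); only the trainable layer changes.
assert torch.equal(model[0].weight, frozen_before)
print("frozen layer untouched, trainable layer updated")
```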