
Commit

improvements
rasbt committed May 8, 2024
1 parent 4bc9c51 commit 88634cc
Showing 6 changed files with 22 additions and 49 deletions.
3 changes: 1 addition & 2 deletions litgpt/external/galore.py
@@ -232,7 +232,6 @@ def get_galore_params(model):
         if not any(target_key in module_name for target_key in target_modules_list):
             continue
 
-        print('enable GaLore for weights in module: ', module_name)
         galore_params.append(module.weight)
     id_galore_params = [id(p) for p in galore_params]
     # make parameters without "rank" to another group
@@ -377,7 +376,7 @@ def __init__(
         eps: float = 1e-6,
         weight_decay: float = 0.0,
         correct_bias: bool = True,
-        no_deprecation_warning: bool = False,
+        no_deprecation_warning: bool = True,
     ):
         if not no_deprecation_warning:
             warnings.warn(
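For orientation, here is a minimal sketch of the parameter split that the finetuning and pretraining diffs below rely on. It is inferred from the hunk above and from how regular_params and galore_params are used at the call sites; the helper name, the default target_modules_list, and the exact filtering are assumptions, and the vendored get_galore_params in litgpt/external/galore.py may differ in detail.

    import torch.nn as nn

    def split_galore_params(model, target_modules_list=("attn", "mlp")):
        # Hypothetical sketch, not the vendored implementation.
        # Collect 2-D weights of targeted Linear modules for GaLore's low-rank
        # gradient projection; every other parameter goes into a plain AdamW group.
        galore_params = []
        for module_name, module in model.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            if not any(target_key in module_name for target_key in target_modules_list):
                continue
            galore_params.append(module.weight)
        id_galore_params = {id(p) for p in galore_params}
        regular_params = [p for p in model.parameters() if id(p) not in id_galore_params]
        return regular_params, galore_params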
12 changes: 4 additions & 8 deletions litgpt/finetune/adapter.py
@@ -29,7 +29,6 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
-    get_linear_nonlinear_params,
     init_out_dir,
     load_checkpoint,
     num_parameters,
@@ -164,16 +163,13 @@ def main(
         optimizer_cls = torch.optim.AdamW
 
     elif optim.optimizer in ("galore_adamw", "galore_adamw_8bit"):
-        if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
-            raise ValueError("The combinatiomn of QLoRA and GaLore is currently not supported.")
+        from litgpt.external.galore import get_galore_params
 
-        linear_params, nonlinear_params = get_linear_nonlinear_params(model)
-        # Currently apply galore to all parameters;
-        # we could add options to target specific layers for AdamW and GaLore later
+        regular_params, galore_params = get_galore_params(model)
         trainable_params = [
-            {'params': nonlinear_params},
+            {'params': regular_params},
             {
-                'params': linear_params,
+                'params': galore_params,
                 'rank': optim.galore_r,
                 'update_proj_gap': optim.galore_update_proj_gap,
                 'scale': optim.galore_scale,
12 changes: 4 additions & 8 deletions litgpt/finetune/adapter_v2.py
@@ -29,7 +29,6 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
-    get_linear_nonlinear_params,
     init_out_dir,
     load_checkpoint,
     num_parameters,
@@ -165,16 +164,13 @@ def main(
         optimizer_cls = torch.optim.AdamW
 
     elif optim.optimizer in ("galore_adamw", "galore_adamw_8bit"):
-        if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
-            raise ValueError("The combinatiomn of QLoRA and GaLore is currently not supported.")
+        from litgpt.external.galore import get_galore_params
 
-        linear_params, nonlinear_params = get_linear_nonlinear_params(model)
-        # Currently apply galore to all parameters;
-        # we could add options to target specific layers for AdamW and GaLore later
+        regular_params, galore_params = get_galore_params(model)
         trainable_params = [
-            {'params': nonlinear_params},
+            {'params': regular_params},
             {
-                'params': linear_params,
+                'params': galore_params,
                 'rank': optim.galore_r,
                 'update_proj_gap': optim.galore_update_proj_gap,
                 'scale': optim.galore_scale,
17 changes: 8 additions & 9 deletions litgpt/finetune/lora.py
@@ -30,7 +30,6 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
-    get_linear_nonlinear_params,
     load_checkpoint,
     init_out_dir,
     num_parameters,
@@ -200,17 +199,17 @@ def main(
         if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
             raise ValueError("The combination of QLoRA and GaLore is currently not supported.")
 
-        linear_params, nonlinear_params = get_linear_nonlinear_params(model)
-        # Currently apply galore to all parameters;
-        # we could add options to target specific layers for AdamW and GaLore later
+        from litgpt.external.galore import get_galore_params
+
+        regular_params, galore_params = get_galore_params(model)
         trainable_params = [
-            {'params': nonlinear_params},
+            {'params': regular_params},
             {
-                'params': linear_params,
+                'params': galore_params,
                 'rank': optim.galore_r,
-                'update_proj_gap': optim.galore_update_proj_gap,
-                'scale': optim.galore_scale,
-                'proj_type': optim.galore_proj_type
+                'update_proj_gap': optim.galore_update_proj_gap,
+                'scale': optim.galore_scale,
+                'proj_type': optim.galore_proj_type
             }
         ]
         if optim.optimizer == "galore_adamw_8bit":
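To show how parameter groups shaped like the ones above are consumed, here is a hedged usage sketch. It uses GaLoreAdamW from the upstream galore-torch package (litgpt/external/galore.py appears to vendor an equivalent optimizer, but its class name is not visible in this diff), a toy model in place of the litgpt GPT, and illustrative stand-in values for optim.galore_r, optim.galore_update_proj_gap, optim.galore_scale, and optim.galore_proj_type.

    import torch.nn as nn
    from galore_torch import GaLoreAdamW  # pip install galore-torch

    # Toy stand-in for the finetuned GPT model.
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))

    # Same split idea as get_galore_params: Linear weights get the GaLore treatment.
    galore_params = [m.weight for m in model.modules() if isinstance(m, nn.Linear)]
    galore_ids = {id(p) for p in galore_params}
    regular_params = [p for p in model.parameters() if id(p) not in galore_ids]

    trainable_params = [
        {'params': regular_params},          # ordinary AdamW group
        {
            'params': galore_params,         # low-rank projected group
            'rank': 8,                       # stand-in for optim.galore_r
            'update_proj_gap': 200,          # stand-in for optim.galore_update_proj_gap
            'scale': 0.25,                   # stand-in for optim.galore_scale
            'proj_type': 'std',              # stand-in for optim.galore_proj_type
        },
    ]
    optimizer = GaLoreAdamW(trainable_params, lr=1e-4, weight_decay=0.0)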
11 changes: 5 additions & 6 deletions litgpt/pretrain.py
@@ -30,7 +30,6 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
-    get_linear_nonlinear_params,
     init_out_dir,
     num_parameters,
     parse_devices,
@@ -190,13 +189,13 @@ def main(
         optimizer_cls = torch.optim.AdamW
 
     elif optim.optimizer in ("galore_adamw", "galore_adamw_8bit"):
-        linear_params, nonlinear_params = get_linear_nonlinear_params(model)
-        # Currently apply galore to all parameters;
-        # we could add options to target specific layers for AdamW and GaLore later
+        from litgpt.external.galore import get_galore_params
+
+        regular_params, galore_params = get_galore_params(model)
         trainable_params = [
-            {'params': nonlinear_params},
+            {'params': regular_params},
             {
-                'params': linear_params,
+                'params': galore_params,
                 'rank': optim.galore_r,
                 'update_proj_gap': optim.galore_update_proj_gap,
                 'scale': optim.galore_scale,
16 changes: 0 additions & 16 deletions litgpt/utils.py
@@ -485,19 +485,3 @@ def choose_logger(
     if logger_name == "wandb":
         return WandbLogger(project=name, resume=resume, **kwargs)
     raise ValueError(f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'.")
-
-
-## REMOVE
-def get_linear_nonlinear_params(model):
-    linear_params = []
-    nonlinear_params = []
-    for module in model.modules():
-        if isinstance(module, torch.nn.Linear):
-            linear_params.extend(list(module.parameters()))
-        else:
-
-            nonlinear_params.extend(list(module.parameters()))
-    linear_params = list(set(linear_params))
-    nonlinear_params = list(set(nonlinear_params) - set(linear_params))
-    return linear_params, nonlinear_params
-