Skip to content

Commit

Permalink
fix get_linear_nonlinear_params
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 6, 2024
1 parent fc17c57 commit bd73193
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions litgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,13 +490,12 @@ def choose_logger(
def get_linear_nonlinear_params(model):
linear_params = []
nonlinear_params = []

for module in model.modules():
if isinstance(module, torch.nn.Linear):
linear_params.extend(list(model.parameters()))
linear_params.extend(list(module.parameters()))
else:
nonlinear_params.extend(list(model.parameters()))

# Make extra sure that there is no overlap
linear_params = list(set(linear_params) - set(nonlinear_params))
nonlinear_params.extend(list(module.parameters()))
linear_params = list(set(linear_params))
nonlinear_params = list(set(nonlinear_params) - set(linear_params))
return linear_params, nonlinear_params

0 comments on commit bd73193

Please sign in to comment.