Commit

add a warning
mwalmsley committed Mar 2, 2024
1 parent ce2c80a commit eaa98ce
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions zoobot/pytorch/training/finetune.py
@@ -112,14 +112,13 @@ def __init__(
            # self.encoder_dim = 9216
        else:
            self.encoder_dim = define_model.get_encoder_dim(self.encoder)
        self.n_blocks = n_blocks
        logging.info('Blocks to finetune: {}'.format(n_layers))

        # for backwards compat.
        if n_layers:
            logging.warning('FinetuneableZoobot(n_layers) is now renamed to n_blocks, please update to pass n_blocks instead! For now, setting n_blocks=n_layers')
            self.n_blocks = n_layers
            logging.info('Layers to finetune: {}'.format(n_layers))
        else:
            self.n_blocks = n_blocks

        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
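
The backwards-compatibility handling in this hunk follows a common deprecation pattern: accept the old keyword, warn, and forward its value to the new one. A minimal, standalone sketch of that pattern (the class name, defaults, and message below are illustrative, not the actual zoobot API):

import logging

class FinetuneableSketch:
    # Illustrative stand-in, not zoobot's FinetuneableZoobot classes.
    def __init__(self, n_blocks=0, n_layers=0):  # defaults assumed for illustration
        if n_layers:
            # deprecated spelling: warn, then map onto the new argument
            logging.warning('n_layers is renamed to n_blocks, please pass n_blocks instead')
            self.n_blocks = n_layers
        else:
            self.n_blocks = n_blocks

# both spellings end up on the same attribute
assert FinetuneableSketch(n_blocks=2).n_blocks == FinetuneableSketch(n_layers=2).n_blocks

Warning rather than raising lets existing calling code keep working while nudging users toward the new name.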
@@ -243,13 +242,16 @@ def configure_optimizers(self):


        logging.info('param groups: {}'.format(len(params)))

        # because it iterates through the generators, THIS BREAKS TRAINING so only uncomment to debug params
        # for param_group_n, param_group in enumerate(params):
        #     shapes_within_param_group = [p.shape for p in list(param_group['params'])]
        #     logging.debug('param group {}: {}'.format(param_group_n, shapes_within_param_group))
        # print('head params to optimize', [p.shape for p in params[0]['params']]) # head only
        # print(list(param_group['params']) for param_group in params)
        # exit()
        # Initialize AdamW optimizer

        opt = torch.optim.AdamW(params, weight_decay=self.weight_decay)  # lr included in params dict
        logging.info('Optimizer ready, configuring scheduler')

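The "THIS BREAKS TRAINING" comment comes down to Python generators being single-pass: if the debug lines iterate a param group whose 'params' entry is a generator before the optimizer is built, the optimizer sees nothing. A small sketch of the failure mode (illustrative; a plain nn.Linear stands in for the encoder):

import torch

layer = torch.nn.Linear(4, 4)
gen = layer.parameters()                 # a one-shot generator
print([p.shape for p in gen])            # "debugging" pass consumes it

try:
    torch.optim.AdamW(gen, lr=1e-4)      # generator is now exhausted
except ValueError as err:
    print(err)                           # e.g. "optimizer got an empty parameter list"
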
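The "# lr included in params dict" comment relies on PyTorch's per-parameter-group options: each dict passed to the optimizer can carry its own 'lr', so only the shared weight_decay is given at construction. A minimal sketch of that pattern (the modules and rates are made up for illustration, loosely mirroring a head at the base rate and a deeper encoder block at a decayed rate):

import torch

head = torch.nn.Linear(512, 2)
encoder_block = torch.nn.Linear(512, 512)

params = [
    {'params': head.parameters(), 'lr': 1e-4},           # head at the base learning rate
    {'params': encoder_block.parameters(), 'lr': 1e-5},  # deeper block at a decayed rate
]
opt = torch.optim.AdamW(params, weight_decay=0.05)       # per-group lr overrides AdamW's default

for group in opt.param_groups:
    print(group['lr'], group['weight_decay'])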
