
Commit

fix n_layers
mwalmsley committed May 30, 2024
1 parent d1a3baa commit 9d0bbb3
Showing 3 changed files with 5 additions and 6 deletions.
6 changes: 3 additions & 3 deletions docs/guides/finetuning.rst
@@ -60,16 +60,16 @@ These files are called checkpoints (like video game save files - computer scient
model = finetune.FinetuneableZoobotClassifier(
name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', # which pretrained model to download
num_classes=2,
- n_layers=0
+ n_blocks=0
)
You can see the list of pretrained models at :doc:`/pretrained_models`.

What about the other arguments?
When loading the checkpoint, FinetuneableZoobotClassifier will automatically change the head layer to suit a classification problem (hence, ``Classifier``).
``num_classes=2`` specifies how many classes we have. Here, there are two classes (a.k.a. binary classification).
- ``n_layers=0`` specifies how many layers (other than the output layer) we want to finetune.
- 0 indicates no other layers, so we will only be changing the weights of the output layer.
+ ``n_blocks=0`` specifies how many inner blocks (groups of layers, excluding the output layer) we want to finetune.
+ 0 indicates no other blocks, so we will only be changing the weights of the output layer.
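
For context (not part of this commit's diff): a minimal sketch of finetuning deeper with the renamed argument, assuming the v2 API shown above. The value ``n_blocks=2`` is illustrative.

from zoobot.pytorch.training import finetune

# same call as in the docs above, but also finetuning the last two
# encoder blocks rather than only the output head
model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    num_classes=2,
    n_blocks=2  # 0 = head only; larger values finetune deeper into the encoder
)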


Prepare Galaxy Data
@@ -44,7 +44,7 @@
model = finetune.FinetuneableZoobotClassifier(
name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
num_classes=2,
- n_layers=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
+ n_blocks=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
)
# under the hood, this does:
# encoder = finetune.load_pretrained_encoder(checkpoint_loc)
3 changes: 1 addition & 2 deletions zoobot/pytorch/training/finetune.py
@@ -68,7 +68,6 @@ class FinetuneableZoobotAbstract(pl.LightningModule):
prog_bar (bool, optional): Print progress bar during finetuning. Defaults to True.
visualize_images (bool, optional): Upload example images to WandB. Good for debugging but slow. Defaults to False.
seed (int, optional): random seed to use. Defaults to 42.
- n_layers: No effect, deprecated. Use n_blocks instead.
"""

def __init__(
@@ -90,7 +89,6 @@ def __init__(
learning_rate=1e-4, # 10x lower than typical, you may like to experiment
dropout_prob=0.5,
always_train_batchnorm=False, # temporarily deprecated
- # n_layers=0, # for backward compat., n_blocks preferred. Now removed in v2.
# these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already
cosine_schedule=False,
warmup_epochs=0,
@@ -101,6 +99,7 @@
# debugging utils
prog_bar=True,
visualize_images=False, # upload examples to wandb, good for debugging
+ n_layers=0, # deprecated (no effect) but can't remove yet as is an arg in some saved checkpoints
seed=42,
):
super().__init__()
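Why re-add ``n_layers`` as a no-op? A minimal sketch (not from this repo, just standard PyTorch Lightning behaviour): ``save_hyperparameters()`` records the ``__init__`` kwargs inside the checkpoint, and ``load_from_checkpoint`` passes them back to ``__init__``, so a checkpoint saved when ``n_layers`` existed cannot be loaded by a class that no longer accepts it.

import pytorch_lightning as pl

class Toy(pl.LightningModule):
    # accept-but-ignore n_layers, mirroring the commit above
    def __init__(self, learning_rate=1e-4, n_layers=0):
        super().__init__()
        self.save_hyperparameters()  # stores {'learning_rate': ..., 'n_layers': ...} in any checkpoint

# A checkpoint saved while n_layers was a real argument includes it in hparams.
# If __init__ stops accepting it, Toy.load_from_checkpoint('old.ckpt') raises
# TypeError: __init__() got an unexpected keyword argument 'n_layers'.
# Keeping the kwarg as a no-op leaves those older checkpoints loadable.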
