From 573c2510aa205acb161cf8a1e46915788b316a5f Mon Sep 17 00:00:00 2001
From: Mike Walmsley
Date: Thu, 30 May 2024 14:41:23 -0400
Subject: [PATCH] refactor sanity check and encoder_dim

---
 zoobot/pytorch/training/finetune.py | 84 ++++++++++++++++-------------
 1 file changed, 47 insertions(+), 37 deletions(-)

diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 405594a2..65a55a38 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -1,6 +1,5 @@
 import logging
 import os
-import warnings
 from functools import partial
 from typing import Any, Optional, Union
 
@@ -118,7 +117,7 @@ def __init__(
         # if you need the encoder to recreate, pass when loading checkpoint e.g.
         # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)
 
-        if name is not None:
+        if name is not None:  # will load from Hub
             assert encoder is None, 'Cannot pass both name and encoder to use'
             if 'greyscale' in name:
                 # I'm not sure why timm is happy to convert color model stem to greyscale
@@ -130,23 +129,25 @@ def __init__(
             self.encoder = timm.create_model(name, num_classes=0, pretrained=True, **timm_kwargs)
             self.encoder_dim = self.encoder.num_features
 
-        elif zoobot_checkpoint_loc is not None:
+        elif zoobot_checkpoint_loc is not None:  # will load from local checkpoint
             assert encoder is None, "Cannot pass both checkpoint to load and encoder to use"
             self.encoder = load_pretrained_zoobot(
                 zoobot_checkpoint_loc
             )  # extracts the timm encoder
             self.encoder_dim = self.encoder.num_features
 
-        else:
+        else:  # passed encoder in-memory directly
             assert (
                 zoobot_checkpoint_loc is None
             ), "Cannot pass both checkpoint to load and encoder to use"
             assert encoder is not None, "Must pass either checkpoint to load or encoder to use"
             self.encoder = encoder
-            # work out encoder dim 'manually'
-            if isinstance(self.encoder, timm.models.VisionTransformer):
-                self.encoder_dim = self.encoder.embed_dim
-            else:
+            # find out encoder dimension
+            if hasattr(self.encoder, 'num_features'):  # timm models generally use this
+                self.encoder_dim = self.encoder.num_features
+            elif hasattr(self.encoder, 'embed_dim'):  # timm.models.VisionTransformer uses this
+                self.encoder_dim = self.encoder.embed_dim
+            else:  # resort to manual estimate
                 self.encoder_dim = define_model.get_encoder_dim(self.encoder)
 
         self.n_blocks = n_blocks
@@ -422,6 +423,7 @@ def load_from_name(cls, name: str, **kwargs):
         )  # trained on GPU, may need map_location='cpu' if you get a device error
 
 
+
 class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
     """
     Pretrained Zoobot model intended for finetuning on a classification problem.
@@ -437,11 +439,12 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract):
         num_classes (int): num. of target classes (e.g. 2 for binary classification).
         label_smoothing (float, optional): See torch cross_entropy_loss docs. Defaults to 0.
         class_weights (arraylike, optional): See torch cross_entropy_loss docs. Defaults to None.
+        run_linear_sanity_check (bool, optional): Before fitting, use sklearn to fit a linear model. Defaults to False.
 
     """
 
     def __init__(
-        self, num_classes: int, label_smoothing=0.0, class_weights=None, **super_kwargs
+        self, num_classes: int, label_smoothing=0.0, class_weights=None, run_linear_sanity_check=False, **super_kwargs
     ) -> None:
         super().__init__(**super_kwargs)
 
@@ -469,6 +472,8 @@ def __init__(
         self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
         self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
 
+        self.run_linear_sanity_check = run_linear_sanity_check
+
     def step_to_dict(self, y, y_pred, loss):
         y_class_preds = torch.argmax(y_pred, axis=1)  # type: ignore
         return {
@@ -478,34 +483,6 @@ def step_to_dict(self, y, y_pred, loss):
             "class_predictions": y_class_preds,
         }
 
-    # Sanity check embeddings with linear evaluation first
-    def on_train_start(self) -> None:
-        with torch.no_grad():
-            embeddings, labels = {"train": [], "val": []}, {"train": [], "val": []}
-
-            # Get validation set embeddings
-            for x, y in self.trainer.datamodule.val_dataloader():
-                embeddings["val"] += self.encoder(x.to(self.device)).cpu()
-                labels["val"] += y
-
-            # Get train set embeddings
-            for x, y in self.trainer.datamodule.train_dataloader():
-                embeddings["train"] += self.encoder(x.to(self.device)).cpu()
-                labels["train"] += y
-
-            # this is linear *train* acc but that's okay, simply test of features
-            model = linear_model.LogisticRegression(penalty=None, max_iter=200)
-            model.fit(embeddings["train"], labels["train"])
-
-            self.log(
-                "finetuning/linear_eval/val",
-                accuracy_score(labels["val"], model.predict(embeddings["val"])),
-            )
-            self.log(
-                "finetuning/linear_eval/train",
-                accuracy_score(labels["train"], model.predict(embeddings["train"])),
-            )
-
     # doesn't need to be torchmetric, only happens in one go? but distributed
     def on_train_batch_end(self, step_output, *args):
         super().on_train_batch_end(step_output, *args)
@@ -564,6 +541,39 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
         ]
         self.logger.log_image(key="val_images", images=images, caption=captions)  # type: ignore
 
+    # Sanity check embeddings with linear evaluation first
+    def on_train_start(self) -> None:
+        if self.run_linear_sanity_check:  # default False
+            self.linear_sanity_check()
+
+    def linear_sanity_check(self):
+        # only implemented on Zoobot...Classifier as assumes accuracy
+        with torch.no_grad():
+            embeddings, labels = {"train": [], "val": []}, {"train": [], "val": []}
+
+            # Get validation set embeddings
+            for x, y in self.trainer.datamodule.val_dataloader():
+                embeddings["val"] += self.encoder(x.to(self.device)).cpu()
+                labels["val"] += y
+
+            # Get train set embeddings
+            for x, y in self.trainer.datamodule.train_dataloader():
+                embeddings["train"] += self.encoder(x.to(self.device)).cpu()
+                labels["train"] += y
+
+            # this is linear *train* acc but that's okay, simply test of features
+            model = linear_model.LogisticRegression(penalty=None, max_iter=200)
+            model.fit(embeddings["train"], labels["train"])
+
+            self.log(
+                "finetuning/linear_eval/val",
+                accuracy_score(labels["val"], model.predict(embeddings["val"])),
+            )
+            self.log(
+                "finetuning/linear_eval/train",
+                accuracy_score(labels["train"], model.predict(embeddings["train"])),
+            )
+            # doesn't need to be torchmetric, only happens in one go? but distributed
 
 class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
     """