diff --git a/docs/guides/finetuning.rst b/docs/guides/finetuning.rst index bce4fb56..3bac92e7 100755 --- a/docs/guides/finetuning.rst +++ b/docs/guides/finetuning.rst @@ -60,7 +60,7 @@ These files are called checkpoints (like video game save files - computer scient model = finetune.FinetuneableZoobotClassifier( name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', # which pretrained model to download num_classes=2, - n_layers=0 + n_blocks=0 ) You can see the list of pretrained models at :doc:`/pretrained_models`. @@ -68,8 +68,8 @@ You can see the list of pretrained models at :doc:`/pretrained_models`. What about the other arguments? When loading the checkpoint, FinetuneableZoobotClassifier will automatically change the head layer to suit a classification problem (hence, ``Classifier``). ``num_classes=2`` specifies how many classes we have, Here, two classes (a.k.a. binary classification). -``n_layers=0`` specifies how many layers (other than the output layer) we want to finetune. -0 indicates no other layers, so we will only be changing the weights of the output layer. +``n_blocks=0`` specifies how many inner blocks (groups of layers, excluding the output layer) we want to finetune. +0 indicates no other blocks, so we will only be changing the weights of the output layer. Prepare Galaxy Data diff --git a/setup.py b/setup.py index 0faa3772..ab906e80 100755 --- a/setup.py +++ b/setup.py @@ -76,8 +76,7 @@ 'albumentations', 'pyro-ppl>=1.8.0', 'torchmetrics==0.11.0', - 'timm >= 0.9.10', - 'galaxy_datasets == 0.0.17' + 'timm >= 0.9.10' ], # TODO may add narval/Digital Research Canada config 'tensorflow': [ # WARNING now deprecated @@ -117,6 +116,6 @@ 'webdataset', # for reading webdataset files 'huggingface_hub', # login may be required 'setuptools', # no longer pinned - 'galaxy-datasets>=0.0.18' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets) + 'galaxy-datasets>=0.0.21' # for dataset loading in both TF and Torch (see github/mwalmsley/galaxy-datasets) ] ) diff --git a/zoobot/pytorch/estimators/define_model.py b/zoobot/pytorch/estimators/define_model.py index f5980492..593ab800 100755 --- a/zoobot/pytorch/estimators/define_model.py +++ b/zoobot/pytorch/estimators/define_model.py @@ -2,15 +2,15 @@ from functools import partial from typing import List -import torch -from torch import nn import pytorch_lightning as pl -import torchmetrics import timm +import torch +import torchmetrics +from torch import nn -from zoobot.shared import schemas -from zoobot.pytorch.estimators import efficientnet_custom, custom_layers +from zoobot.pytorch.estimators import custom_layers, efficientnet_custom from zoobot.pytorch.training import losses, schedulers +from zoobot.shared import schemas # overall strategy # timm for defining complicated pytorch modules @@ -32,7 +32,7 @@ # FinetuneableZoobotClassifier(pretrained_model.encoder, optim_args, task_args) # (same approach for FinetuneableZoobotTree) -# to use just the encoder later: +# to use just the encoder later: # encoder = load_pretrained_encoder(pyramid=False) # when pyramid=True, reset the timm model to pull lightning features (TODO) @@ -42,7 +42,6 @@ # LinearClassifier(output_dim, dropout) - class GenericLightningModule(pl.LightningModule): """ All Zoobot models use the lightningmodule API and so share this structure @@ -50,66 +49,82 @@ class GenericLightningModule(pl.LightningModule): only assumes an encoder and a head """ - def __init__( - self, - *args, # to be saved as hparams - ): + # args will be saved as hparams + def 
__init__(self, *args): super().__init__() - self.save_hyperparameters() # saves all args by default - - - def setup_metrics(self, nan_strategy='error'): # may sometimes want to ignore nan even in main metrics - self.val_accuracy = torchmetrics.Accuracy(task='binary') + # saves all args by default + self.save_hyperparameters() + + # may sometimes want to ignore nan even in main metrics + def setup_metrics(self, nan_strategy="error"): + self.val_accuracy = torchmetrics.Accuracy(task="binary") + + self.loss_metrics = torch.nn.ModuleDict( + { + "train/supervised_loss": torchmetrics.MeanMetric( + nan_strategy=nan_strategy + ), + "validation/supervised_loss": torchmetrics.MeanMetric( + nan_strategy=nan_strategy + ), + "test/supervised_loss": torchmetrics.MeanMetric( + nan_strategy=nan_strategy + ), + } + ) - self.loss_metrics = torch.nn.ModuleDict({ - 'train/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), - 'validation/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), - 'test/supervised_loss': torchmetrics.MeanMetric(nan_strategy=nan_strategy), - }) - # TODO handle when schema doesn't exist question_metric_dict = {} - for step_name in ['train', 'validation', 'test']: - question_metric_dict.update({ - step_name + '/question_loss/' + question.text: torchmetrics.MeanMetric(nan_strategy='ignore') - for question in self.schema.questions - }) + for step_name in ["train", "validation", "test"]: + question_metric_dict.update( + { + step_name + + "/question_loss/" + + question.text: torchmetrics.MeanMetric(nan_strategy="ignore") + for question in self.schema.questions + } + ) self.question_loss_metrics = torch.nn.ModuleDict(question_metric_dict) campaigns = schema_to_campaigns(self.schema) campaign_metric_dict = {} - for step_name in ['train', 'validation', 'test']: - campaign_metric_dict.update({ - step_name + '/campaign_loss/' + campaign: torchmetrics.MeanMetric(nan_strategy='ignore') - for campaign in campaigns - }) + for step_name in ["train", "validation", "test"]: + campaign_metric_dict.update( + { + step_name + + "/campaign_loss/" + + campaign: torchmetrics.MeanMetric(nan_strategy="ignore") + for campaign in campaigns + } + ) self.campaign_loss_metrics = torch.nn.ModuleDict(campaign_metric_dict) - def forward(self, x): assert x.shape[1] < 4 # torchlike BCHW x = self.encoder(x) return self.head(x) - + def make_step(self, batch, step_name): x, labels = batch predictions = self(x) # by default, these are Dirichlet concentrations - loss = self.calculate_loss_and_update_loss_metrics(predictions, labels, step_name) - outputs = {'loss': loss, 'predictions': predictions, 'labels': labels} + loss = self.calculate_loss_and_update_loss_metrics( + predictions, labels, step_name + ) + outputs = {"loss": loss, "predictions": predictions, "labels": labels} # self.update_other_metrics(outputs, step_name=step_name) return outputs def configure_optimizers(self): - raise NotImplementedError('Must be subclassed') + raise NotImplementedError("Must be subclassed") def training_step(self, batch, batch_idx): - return self.make_step(batch, step_name='train') + return self.make_step(batch, step_name="train") def validation_step(self, batch, batch_idx): - return self.make_step(batch, step_name='validation') - + return self.make_step(batch, step_name="validation") + def test_step(self, batch, batch_idx): - return self.make_step(batch, step_name='test') + return self.make_step(batch, step_name="test") # def on_train_batch_end(self, outputs, *args): # pass @@ -119,40 +134,58 @@ def 
test_step(self, batch, batch_idx): def on_train_epoch_end(self) -> None: # called *after* on_validation_epoch_end, confusingly - # do NOT log_all_metrics here. + # do NOT log_all_metrics here. # logging a metric resets it, and on_validation_epoch_end just logged and reset everything, so you will only log nans - self.log_all_metrics(subset='train') + self.log_all_metrics(subset="train") def on_validation_epoch_end(self) -> None: - self.log_all_metrics(subset='validation') + self.log_all_metrics(subset="validation") def on_test_epoch_end(self) -> None: # logging.info('start test epoch end') - self.log_all_metrics(subset='test') + self.log_all_metrics(subset="test") # logging.info('end test epoch end') - + def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name): - raise NotImplementedError('Must be subclassed') - + raise NotImplementedError("Must be subclassed") + def update_other_metrics(self, outputs, step_name): - raise NotImplementedError('Must be subclassed') + raise NotImplementedError("Must be subclassed") def log_all_metrics(self, subset=None): if subset is not None: - for metric_collection in (self.loss_metrics, self.question_loss_metrics, self.campaign_loss_metrics): + for metric_collection in ( + self.loss_metrics, + self.question_loss_metrics, + self.campaign_loss_metrics, + ): prog_bar = metric_collection == self.loss_metrics for name, metric in metric_collection.items(): if subset in name: # logging.info(name) - self.log(name, metric, on_epoch=True, on_step=False, prog_bar=prog_bar, logger=True) + self.log( + name, + metric, + on_epoch=True, + on_step=False, + prog_bar=prog_bar, + logger=True, + ) else: # just log everything - self.log_dict(self.loss_metrics, on_epoch=True, on_step=False, prog_bar=True, logger=True) - self.log_dict(self.question_loss_metrics, on_step=False, on_epoch=True, logger=True) - self.log_dict(self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True) - - + self.log_dict( + self.loss_metrics, + on_epoch=True, + on_step=False, + prog_bar=True, + logger=True, + ) + self.log_dict( + self.question_loss_metrics, on_step=False, on_epoch=True, logger=True + ) + self.log_dict( + self.campaign_loss_metrics, on_step=False, on_epoch=True, logger=True + ) - def predict_step(self, batch, batch_idx, dataloader_idx=0): # I can't work out how to get webdataset to return a single item im, not a tuple (im,). # this is fine for training but annoying for predict @@ -167,11 +200,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): class ZoobotTree(GenericLightningModule): """ - + The Zoobot model. Train from scratch using :func:`zoobot.pytorch.training.train_with_pytorch_lightning.train_default_zoobot_from_scratch`. PyTorch LightningModule describing how to train the encoder and head (described below). - Trains using Dirichlet loss. Labels should be num. volunteers giving each answer to each question. + Trains using Dirichlet loss. Labels should be num. volunteers giving each answer to each question. See the code for exact training step, logging, etc - there's a lot of detail. 
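The GenericLightningModule refactored above keeps one torchmetrics.MeanMetric per subset inside torch.nn.ModuleDicts and logs them only at epoch end, because logging a metric also resets it. A minimal, self-contained sketch of that pattern (the toy linear model and MSE loss are illustrative only, not Zoobot code):

import torch
import torchmetrics
import pytorch_lightning as pl


class LossTrackingModule(pl.LightningModule):
    """Toy module showing the ModuleDict-of-MeanMetric pattern used by GenericLightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)
        # one MeanMetric per subset; nan_strategy='error' surfaces bad batches early
        self.loss_metrics = torch.nn.ModuleDict({
            "train/supervised_loss": torchmetrics.MeanMetric(nan_strategy="error"),
            "validation/supervised_loss": torchmetrics.MeanMetric(nan_strategy="error"),
        })

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x).squeeze(-1), y)
        self.loss_metrics["train/supervised_loss"](loss)  # update only, no per-step logging
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x).squeeze(-1), y)
        self.loss_metrics["validation/supervised_loss"](loss)

    def on_train_epoch_end(self):
        # logging the metric object computes the epoch mean and resets it, so log exactly once per epoch
        self.log("train/supervised_loss", self.loss_metrics["train/supervised_loss"])

    def on_validation_epoch_end(self):
        self.log("validation/supervised_loss", self.loss_metrics["validation/supervised_loss"])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)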
@@ -193,10 +226,10 @@ def __init__( # in the simplest case, this is all zoobot needs: grouping of label col indices as questions # question_index_groups: List=None, # BUT - # if you pass these, it enables better per-question and per-survey logging (because we have names) + # if you pass these, it enables better per-question and per-survey logging (because we have names) # must be passed as simple dicts, not objects, so can't just pass schema in - question_answer_pairs: dict=None, - dependencies: dict=None, + question_answer_pairs: dict = None, + dependencies: dict = None, # encoder args architecture_name="efficientnet_b0", channels=1, @@ -210,8 +243,8 @@ def __init__( # optim args betas=(0.9, 0.999), # PyTorch default weight_decay=0.01, # AdamW PyTorch default - scheduler_params={} # no scheduler by default - ): + scheduler_params={}, # no scheduler by default + ): # now, finally, can pass only standard variables as hparams to save # will still need to actually use these variables later, this super init only saves them @@ -229,10 +262,10 @@ def __init__( learning_rate, betas, weight_decay, - scheduler_params + scheduler_params, ) - logging.info('Generic __init__ complete - moving to Zoobot __init__') + logging.info("Generic __init__ complete - moving to Zoobot __init__") # logging.info('question_index_groups/dependencies passed to Zoobot, constructing schema in __init__') # assert question_index_groups is None, "Don't pass both question_index_groups and question_answer_pairs/dependencies" @@ -254,30 +287,28 @@ def __init__( architecture_name, channels, # use_imagenet_weights=use_imagenet_weights, - **timm_kwargs + **timm_kwargs, ) if compile_encoder: - logging.warning('Using torch.compile on encoder') + logging.warning("Using torch.compile on encoder") self.encoder = torch.compile(self.encoder) # bit lazy assuming 224 input size # logging.warning(channels) self.encoder_dim = get_encoder_dim(self.encoder, channels) # typically encoder_dim=1280 for effnetb0 - logging.info('encoder dim: {}'.format(self.encoder_dim)) - + logging.info("encoder dim: {}".format(self.encoder_dim)) self.head = get_pytorch_dirichlet_head( self.encoder_dim, output_dim=output_dim, test_time_dropout=test_time_dropout, - dropout_rate=dropout_rate + dropout_rate=dropout_rate, ) self.loss_func = get_dirichlet_loss_func(question_index_groups) - logging.info('Zoobot __init__ complete') - + logging.info("Zoobot __init__ complete") def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name): # self.loss_func returns shape of (galaxy, question), mean to () @@ -287,10 +318,9 @@ def calculate_loss_and_update_loss_metrics(self, predictions, labels, step_name) # for DDP strategy, batch size is constant (batches are not divided, data pool is divided) # so this will be the global per-example mean loss = torch.mean(torch.sum(multiq_loss, axis=1)) - self.loss_metrics[step_name + '/supervised_loss'](loss) + self.loss_metrics[step_name + "/supervised_loss"](loss) return loss - def configure_optimizers(self): # designed for training from scratch # parameters = list(self.head.parameters()) + list(self.encoder.parameters()) TODO should happen automatically? 
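In the ZoobotTree __init__ above, the encoder comes from timm via get_pytorch_encoder (headless, num_classes=0) and its feature width is then probed by get_encoder_dim with a dummy forward pass. A short sketch of that probing idea, assuming the efficientnet_b0 example architecture:

import timm
import torch


def probe_encoder_dim(encoder: torch.nn.Module, channels: int = 1, size: int = 224) -> int:
    # infer the pooled feature dimension by pushing a dummy BCHW batch through the encoder
    device = next(encoder.parameters()).device
    with torch.no_grad():
        dummy = torch.randn(2, channels, size, size, device=device)
        return encoder(dummy).shape[-1]


# num_classes=0 asks timm for a headless model that returns pooled features
encoder = timm.create_model("efficientnet_b0", in_chans=1, num_classes=0)
print(probe_encoder_dim(encoder, channels=1))  # typically 1280 for efficientnet_b0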
@@ -298,33 +328,40 @@ def configure_optimizers(self): self.parameters(), lr=self.learning_rate, betas=self.betas, - weight_decay=self.weight_decay - ) - if self.scheduler_params.get('name', None) == 'plateau': - logging.info(f'Using Plateau scheduler with {self.scheduler_params}') + weight_decay=self.weight_decay, + ) + if self.scheduler_params.get("name", None) == "plateau": + logging.info(f"Using Plateau scheduler with {self.scheduler_params}") # TODO could generalise this if needed scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, + optimizer, min_lr=1e-6, - patience=self.scheduler_params.get('patience', 5) + patience=self.scheduler_params.get("patience", 5), ) - return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'validation/loss'} - elif self.scheduler_params.get('cosine_schedule', False): - logging.info('Using cosine schedule') + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "validation/loss", + } + elif self.scheduler_params.get("cosine_schedule", False): + logging.info("Using cosine schedule") scheduler = schedulers.CosineWarmupScheduler( optimizer=optimizer, - warmup_epochs=self.scheduler_params['warmup_epochs'], - max_epochs=self.scheduler_params['max_cosine_epochs'], + warmup_epochs=self.scheduler_params["warmup_epochs"], + max_epochs=self.scheduler_params["max_cosine_epochs"], start_value=self.learning_rate, - end_value=self.learning_rate * self.scheduler_params['max_learning_rate_reduction_factor'] + end_value=self.learning_rate + * self.scheduler_params["max_learning_rate_reduction_factor"], ) - return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'validation/loss'} + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": "validation/loss", + } else: - logging.info('No scheduler used') + logging.info("No scheduler used") return optimizer # no scheduler - - def update_per_question_loss_metric(self, multiq_loss, step_name): # log questions individually # TODO need schema attribute or similar to have access to question names, this will do for now @@ -332,53 +369,68 @@ def update_per_question_loss_metric(self, multiq_loss, step_name): # TODO could use TorchMetrics and for q in schema, self.q_metric loop # if hasattr(self, 'schema'): - # use schema metadata to log intelligently - # will have schema if question_answer_pairs and dependencies are passed to __init__ - # assume that questions are named like smooth-or-featured-CAMPAIGN + # use schema metadata to log intelligently + # will have schema if question_answer_pairs and dependencies are passed to __init__ + # assume that questions are named like smooth-or-featured-CAMPAIGN for question_n, question in enumerate(self.schema.questions): # for logging comparison, want to ignore loss on unlablled examples, i.e. 
take mean ignoring zeros # could sum, but then this would vary with batch size - nontrivial_loss_mask = multiq_loss[:, question_n] > 0 # 'zero' seems to be ~5e-5 floor in practice + nontrivial_loss_mask = ( + multiq_loss[:, question_n] > 0 + ) # 'zero' seems to be ~5e-5 floor in practice - this_question_metric = self.question_loss_metrics[step_name + '/question_loss/' + question.text] + this_question_metric = self.question_loss_metrics[ + step_name + "/question_loss/" + question.text + ] # raise ValueError - this_question_metric(torch.mean(multiq_loss[nontrivial_loss_mask, question_n])) + this_question_metric( + torch.mean(multiq_loss[nontrivial_loss_mask, question_n]) + ) campaigns = schema_to_campaigns(self.schema) for campaign in campaigns: - campaign_questions = [q for q in self.schema.questions if campaign in q.text] - campaign_q_indices = [self.schema.questions.index(q) for q in campaign_questions] # shape (num q in this campaign e.g. 10) + campaign_questions = [ + q for q in self.schema.questions if campaign in q.text + ] + campaign_q_indices = [ + self.schema.questions.index(q) for q in campaign_questions + ] # shape (num q in this campaign e.g. 10) # similarly to per-question, only include in mean if (any) q in this campaign has a non-trivial loss - nontrivial_loss_mask = multiq_loss[:, campaign_q_indices].sum(axis=1) > 0 # shape batch size - - this_campaign_metric = self.campaign_loss_metrics[step_name + '/campaign_loss/' + campaign] - this_campaign_metric(torch.mean(multiq_loss[nontrivial_loss_mask][:, campaign_q_indices])) + nontrivial_loss_mask = ( + multiq_loss[:, campaign_q_indices].sum(axis=1) > 0 + ) # shape batch size + + this_campaign_metric = self.campaign_loss_metrics[ + step_name + "/campaign_loss/" + campaign + ] + this_campaign_metric( + torch.mean(multiq_loss[nontrivial_loss_mask][:, campaign_q_indices]) + ) # else: # # fallback to logging with question_n # for question_n in range(multiq_loss.shape[1]): # self.log(f'{step_name}/questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, sync_dist=True) - - - - def get_dirichlet_loss_func(question_index_groups): # This just adds schema.question_index_groups as an arg to the usual (labels, preds) loss arg format # Would use lambda but multi-gpu doesn't support as lambda can't be pickled return partial(dirichlet_loss, question_index_groups=question_index_groups) - # accept (labels, preds), return losses of shape (batch, question) + + def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=False): # pytorch convention is preds, labels for loss func # my and sklearn convention is labels, preds for loss func # multiquestion_loss returns loss of shape (batch, question) # torch.sum(multiquestion_loss, axis=1) gives loss of shape (batch). Equiv. to non-log product of question likelihoods. 
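    # Worked shape example for the two return modes (illustrative only, assuming 3 questions and batch size B):
    #   calculate_multiquestion_loss(labels, preds, question_index_groups) -> (B, 3)   per galaxy, per question
    #   sum_over_questions=True  -> torch.sum(multiq_loss, axis=1) -> (B,)             one loss per galaxy
    #   sum_over_questions=False -> (B, 3) unchanged, used for per-question and per-campaign logging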
- multiq_loss = losses.calculate_multiquestion_loss(labels, preds, question_index_groups, careful=True) + multiq_loss = losses.calculate_multiquestion_loss( + labels, preds, question_index_groups, careful=True + ) if sum_over_questions: return torch.sum(multiq_loss, axis=1) else: @@ -390,24 +442,26 @@ def dirichlet_loss(preds, labels, question_index_groups, sum_over_questions=Fals def get_encoder_dim(encoder, channels=3): device = next(encoder.parameters()).device try: - x = torch.randn(2, channels, 224, 224, device=device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] except RuntimeError as e: - if 'channels instead' in str(e): - logging.info('encoder dim search failed on channels, trying with channels=1') + if "channels instead" in str(e): + logging.info( + "encoder dim search failed on channels, trying with channels=1" + ) channels = 1 - x = torch.randn(2, channels, 224, 224, device=device) # BCHW + x = torch.randn(2, channels, 224, 224, device=device) # BCHW return encoder(x).shape[-1] else: raise e - + def get_pytorch_encoder( - architecture_name='efficientnet_b0', + architecture_name="efficientnet_b0", channels=1, # use_imagenet_weights=False, - **timm_kwargs - ) -> nn.Module: + **timm_kwargs, +) -> nn.Module: """ Create a trainable efficientnet model. First layers are galaxy-appropriate augmentation layers - see :meth:`zoobot.estimators.define_model.add_augmentation_layers`. @@ -425,7 +479,7 @@ def get_pytorch_encoder( weights_loc (str, optional): If str, load weights from efficientnet checkpoint at this location. Defaults to None. include_top (bool, optional): If True, include head used for GZ DECaLS: global pooling and dense layer. Defaults to True. expect_partial (bool, optional): If True, do not raise partial match error when loading weights (likely for optimizer state). Defaults to False. - channels (int, default 1): Number of channels i.e. C in NHWC-dimension inputs. + channels (int, default 1): Number of channels i.e. C in NHWC-dimension inputs. Returns: torch.nn.Sequential: trainable efficientnet model including augmentations and optional head @@ -436,26 +490,32 @@ def get_pytorch_encoder( # if architecture_name == 'toy': # logging.warning('Using toy encoder') # return ToyEncoder() - + # support older code that didn't specify effnet version - if architecture_name == 'efficientnet': - logging.warning('efficientnet variant not specified - please set architecture_name=efficientnet_b0 (or similar)') - architecture_name = 'efficientnet_b0' - return timm.create_model(architecture_name, in_chans=channels, num_classes=0, **timm_kwargs) + if architecture_name == "efficientnet": + logging.warning( + "efficientnet variant not specified - please set architecture_name=efficientnet_b0 (or similar)" + ) + architecture_name = "efficientnet_b0" + return timm.create_model( + architecture_name, in_chans=channels, num_classes=0, **timm_kwargs + ) -def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float) -> torch.nn.Sequential: +def get_pytorch_dirichlet_head( + encoder_dim: int, output_dim: int, test_time_dropout: bool, dropout_rate: float +) -> torch.nn.Sequential: """ Head to combine with encoder (above) when predicting Galaxy Zoo decision tree answers. Pytorch Sequential model. Predicts Dirichlet concentration parameters. - + Also used when finetuning on a new decision tree - see :class:`zoobot.pytorch.training.finetune.FinetuneableZoobotTree`. 
Args: encoder_dim (int): dimensions of preceding encoder i.e. the input size expected by this submodel. output_dim (int): output dimensions of this head e.g. 34 to predict 34 answers. - test_time_dropout (bool): Use dropout at test time. + test_time_dropout (bool): Use dropout at test time. dropout_rate (float): P of dropout. See torch.nn.Dropout docs. Returns: @@ -465,29 +525,30 @@ def get_pytorch_dirichlet_head(encoder_dim: int, output_dim: int, test_time_drop modules_to_use = [] assert output_dim is not None - # no AdaptiveAvgPool2d, encoder assumed to pool already + # no AdaptiveAvgPool2d, encoder assumed to pool already if test_time_dropout: - logging.info('Using test-time dropout') + logging.info("Using test-time dropout") dropout_layer = custom_layers.PermaDropout else: - logging.info('Not using test-time dropout') + logging.info("Not using test-time dropout") dropout_layer = torch.nn.Dropout modules_to_use.append(dropout_layer(dropout_rate)) # TODO could optionally add a bottleneck layer here - modules_to_use.append(efficientnet_custom.custom_top_dirichlet(encoder_dim, output_dim)) + modules_to_use.append( + efficientnet_custom.custom_top_dirichlet(encoder_dim, output_dim) + ) return nn.Sequential(*modules_to_use) def schema_to_campaigns(schema): # e.g. [gz2, dr12, ...] - return [question.text.split('-')[-1] for question in schema.questions] + return [question.text.split("-")[-1] for question in schema.questions] -if __name__ == '__main__': +if __name__ == "__main__": encoder = get_pytorch_encoder(channels=1) dim = get_encoder_dim(encoder, channels=1) print(dim) - - ZoobotTree.load_from_checkpoint \ No newline at end of file + ZoobotTree.load_from_checkpoint diff --git a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py index 4cf7efff..e4b44394 100644 --- a/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py +++ b/zoobot/pytorch/examples/finetuning/finetune_binary_classification.py @@ -44,7 +44,7 @@ model = finetune.FinetuneableZoobotClassifier( name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano', num_classes=2, - n_layers=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper. + n_blocks=0 # only updating the head weights. Set e.g. 1, 2 to finetune deeper. 
) # under the hood, this does: # encoder = finetune.load_pretrained_encoder(checkpoint_loc) diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py index 145e08df..4588d7f4 100644 --- a/zoobot/pytorch/training/finetune.py +++ b/zoobot/pytorch/training/finetune.py @@ -1,22 +1,22 @@ import logging import os -from typing import Any, Union, Optional -import warnings from functools import partial +from typing import Any, Optional, Union import numpy as np import pytorch_lightning as pl -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.callbacks import LearningRateMonitor - +import timm import torch import torch.nn.functional as F import torchmetrics as tm -import timm +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from sklearn import linear_model +from sklearn.metrics import accuracy_score -from zoobot.pytorch.training import losses, schedulers from zoobot.pytorch.estimators import define_model +from zoobot.pytorch.training import losses, schedulers from zoobot.shared import schemas # https://discuss.pytorch.org/t/how-to-freeze-bn-layers-while-training-the-rest-of-network-mean-and-var-wont-freeze/89736/7 @@ -24,9 +24,9 @@ def freeze_batchnorm_layers(model): - for name, child in (model.named_children()): + for name, child in model.named_children(): if isinstance(child, torch.nn.BatchNorm2d): - logging.debug('Freezing {} {}'.format(child, name)) + logging.debug("Freezing {} {}".format(child, name)) child.eval() # no grads, no param updates, no statistic updates else: freeze_batchnorm_layers(child) # recurse @@ -38,7 +38,7 @@ class FinetuneableZoobotAbstract(pl.LightningModule): You cannot use this class directly - you must use the child classes above instead. This class defines the shared finetuning args and methods used by those child classes. - For example: + For example: - When provided `name`, it will load the HuggingFace encoder with that name (see below for more). - When provided `learning_rate` it will set the optimizer to use that learning rate. @@ -53,8 +53,8 @@ class FinetuneableZoobotAbstract(pl.LightningModule): name (str, optional): Name of a model on HuggingFace Hub e.g.'hf_hub:mwalmsley/zoobot-encoder-convnext_nano'. Defaults to None. encoder (torch.nn.Module, optional): A PyTorch model already loaded in memory zoobot_checkpoint_loc (str, optional): Path to ZoobotTree lightning checkpoint to load. Loads with Load with :func:`zoobot.pytorch.training.finetune.load_pretrained_zoobot`. Defaults to None. - - n_blocks (int, optional): + + n_blocks (int, optional): lr_decay (float, optional): For each layer i below the head, reduce the learning rate by lr_decay ^ i. Defaults to 0.75. weight_decay (float, optional): AdamW weight decay arg (i.e. L2 penalty). Defaults to 0.05. learning_rate (float, optional): AdamW learning rate arg. Defaults to 1e-4. @@ -68,25 +68,20 @@ class FinetuneableZoobotAbstract(pl.LightningModule): prog_bar (bool, optional): Print progress bar during finetuning. Defaults to True. visualize_images (bool, optional): Upload example images to WandB. Good for debugging but slow. Defaults to False. seed (int, optional): random seed to use. Defaults to 42. - n_layers: No effect, deprecated. Use n_blocks instead. 
""" def __init__( self, - # load a pretrained timm encoder saved on huggingface hub # (aimed at most users, easiest way to load published models) name=None, - # ...or directly pass any model to use as encoder (if you do this, you will need to keep it around for later) # (aimed at tinkering with new architectures e.g. SSL) encoder=None, # use any torch model already loaded in memory (must have .forward() method) - # load a pretrained zoobottree model and grab the encoder (a timm model) # requires the exact same zoobot version used for training, not very portable # (aimed at supervised experiments) - zoobot_checkpoint_loc=None, - + zoobot_checkpoint_loc=None, # finetuning settings n_blocks=0, # how many layers deep to FT lr_decay=0.75, @@ -94,7 +89,6 @@ def __init__( learning_rate=1e-4, # 10x lower than typical, you may like to experiment dropout_prob=0.5, always_train_batchnorm=False, # temporarily deprecated - # n_layers=0, # for backward compat., n_blocks preferred. Now removed in v2. # these args are for the optional learning rate scheduler, best not to use unless you've tuned everything else already cosine_schedule=False, warmup_epochs=0, @@ -105,8 +99,8 @@ def __init__( # debugging utils prog_bar=True, visualize_images=False, # upload examples to wandb, good for debugging + n_layers=0, # deprecated (no effect) but can't remove yet as is an arg in some saved checkpoints seed=42, - n_layers=None, # deprecated, no effect ): super().__init__() @@ -114,15 +108,15 @@ def __init__( # will also add to wandb if using logging=wandb, I think # necessary if you want to reload! # with warnings.catch_warnings(): - # warnings.simplefilter("ignore") - # this raises a warning that encoder is already a Module hence saved in checkpoint hence no need to save as hparam - # true - except we need it to instantiate this class, so it's really handy to have saved as well - # therefore ignore the warning - self.save_hyperparameters(ignore=['encoder']) # never serialise the encoder, way too heavy - # if you need the encoder to recreate, pass when loading checkpoint e.g. - # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder) - - if name is not None: + # warnings.simplefilter("ignore") + # this raises a warning that encoder is already a Module hence saved in checkpoint hence no need to save as hparam + # true - except we need it to instantiate this class, so it's really handy to have saved as well + # therefore ignore the warning + self.save_hyperparameters(ignore=["encoder"]) # never serialise the encoder, way too heavy + # if you need the encoder to recreate, pass when loading checkpoint e.g. 
+ # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder) + + if name is not None: # will load from Hub assert encoder is None, 'Cannot pass both name and encoder to use' if 'greyscale' in name: # I'm not sure why timm is happy to convert color model stem to greyscale @@ -134,16 +128,26 @@ def __init__( self.encoder = timm.create_model(name, num_classes=0, pretrained=True, **timm_kwargs) self.encoder_dim = self.encoder.num_features - elif zoobot_checkpoint_loc is not None: - assert encoder is None, 'Cannot pass both checkpoint to load and encoder to use' - self.encoder = load_pretrained_zoobot(zoobot_checkpoint_loc) # extracts the timm encoder + elif zoobot_checkpoint_loc is not None: # will load from local checkpoint + assert encoder is None, "Cannot pass both checkpoint to load and encoder to use" + self.encoder = load_pretrained_zoobot( + zoobot_checkpoint_loc + ) # extracts the timm encoder self.encoder_dim = self.encoder.num_features - else: - assert zoobot_checkpoint_loc is None, 'Cannot pass both checkpoint to load and encoder to use' - assert encoder is not None, 'Must pass either checkpoint to load or encoder to use' + + else: # passed encoder in-memory directly + assert ( + zoobot_checkpoint_loc is None + ), "Cannot pass both checkpoint to load and encoder to use" + assert encoder is not None, "Must pass either checkpoint to load or encoder to use" self.encoder = encoder - # work out encoder dim 'manually' - self.encoder_dim = define_model.get_encoder_dim(self.encoder) + # find out encoder dimension + if hasattr(self.encoder, 'num_features'): # timm models generally use this + self.encoder_dim = self.encoder.num_features + elif hasattr(self.encoder, 'embed_dim'): # timm.models.VisionTransformer uses this + self.encoder_dim = self.encoder.embed_dim + else: # resort to manual estimate + self.encoder_dim = define_model.get_encoder_dim(self.encoder) self.n_blocks = n_blocks @@ -161,7 +165,9 @@ def __init__( self.always_train_batchnorm = always_train_batchnorm if self.always_train_batchnorm: - raise NotImplementedError('Temporarily deprecated, always_train_batchnorm=True not supported') + raise NotImplementedError( + "Temporarily deprecated, always_train_batchnorm=True not supported" + ) # logging.info('always_train_batchnorm=True, so all batch norm layers will be finetuned') self.train_loss_metric = tm.MeanMetric() @@ -172,16 +178,24 @@ def __init__( self.prog_bar = prog_bar self.visualize_images = visualize_images - def configure_optimizers(self): + # Remove ViT head if it exists + if hasattr(self.encoder, "head") and isinstance( + self.encoder, timm.models.VisionTransformer + ): + # If the encoder has a 'head' attribute, replace it with Identity() + self.encoder.head = torch.nn.Identity() + logging.info("Replaced encoder.head with Identity()") + + def configure_optimizers(self): """ This controls which parameters get optimized self.head is always optimized, with no learning rate decay when self.n_blocks == 0, only self.head is optimized (i.e. 
frozen* encoder) - + for self.encoder, we enumerate the blocks (groups of layers) to potentially finetune and then pick the top self.n_blocks to finetune - + weight_decay is applied to both the head and (if relevant) the encoder learning rate decay is applied to the encoder only: lr x (lr_decay^block_n), ignoring the head (block 0) @@ -193,14 +207,16 @@ def configure_optimizers(self): lr = self.learning_rate params = [{"params": self.head.parameters(), "lr": lr}] - logging.info(f'Encoder architecture to finetune: {type(self.encoder)}') + logging.info(f"Encoder architecture to finetune: {type(self.encoder)}") if self.from_scratch: - logging.warning('self.from_scratch is True, training everything and ignoring all settings') + logging.warning( + "self.from_scratch is True, training everything and ignoring all settings" + ) params += [{"params": self.encoder.parameters(), "lr": lr}] return torch.optim.AdamW(params, weight_decay=self.weight_decay) - if isinstance(self.encoder, timm.models.EfficientNet): # includes v2 + if isinstance(self.encoder, timm.models.EfficientNet): # includes v2 # TODO for now, these count as separate layers, not ideal early_tuneable_layers = [self.encoder.conv_stem, self.encoder.bn1] encoder_blocks = list(self.encoder.blocks) @@ -214,38 +230,44 @@ def configure_optimizers(self): self.encoder.layer1, self.encoder.layer2, self.encoder.layer3, - self.encoder.layer4 + self.encoder.layer4, ] elif isinstance(self.encoder, timm.models.MaxxVit): tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages] elif isinstance(self.encoder, timm.models.ConvNeXt): # stem + 4 blocks, for all sizes # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264 tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages] + elif isinstance(self.encoder, timm.models.VisionTransformer): + tuneable_blocks = [self.encoder.patch_embed] + [block for block in self.encoder.blocks] else: - raise ValueError(f'Encoder architecture not automatically recognised: {type(self.encoder)}') - + raise ValueError( + f"Encoder architecture not automatically recognised: {type(self.encoder)}" + ) + assert self.n_blocks <= len( tuneable_blocks ), f"Network only has {len(tuneable_blocks)} tuneable blocks, {self.n_blocks} specified for finetuning" - # take n blocks, ordered highest layer to lowest layer tuneable_blocks.reverse() - logging.info('possible blocks to tune: {}'.format(len(tuneable_blocks))) + logging.info(f"possible blocks to tune: {len(tuneable_blocks)}") + # will finetune all params in first N - logging.info('blocks that will be tuned: {}'.format(self.n_blocks)) - blocks_to_tune = tuneable_blocks[:self.n_blocks] + logging.info(f"blocks that will be tuned: {self.n_blocks}") + blocks_to_tune = tuneable_blocks[: self.n_blocks] + # optionally, can finetune batchnorm params in remaining layers - remaining_blocks = tuneable_blocks[self.n_blocks:] - logging.info('Remaining blocks: {}'.format(len(remaining_blocks))) - assert not any([block in remaining_blocks for block in blocks_to_tune]), 'Some blocks are in both tuneable and remaining' + remaining_blocks = tuneable_blocks[self.n_blocks :] + logging.info(f"Remaining blocks: {len(remaining_blocks)}") + + assert not any( + [block in remaining_blocks for block in blocks_to_tune] + ), "Some blocks are in both tuneable and remaining" # Append parameters of layers for finetuning along with decayed learning rate for i, block in enumerate(blocks_to_tune): # _ is the block name e.g. 
'3' - params.append({ - "params": block.parameters(), - "lr": lr * (self.lr_decay**i) - }) + logging.info(f"Adding block {block} with lr {lr * (self.lr_decay**i)}") + params.append({"params": block.parameters(), "lr": lr * (self.lr_decay**i)}) # optionally, for the remaining layers (not otherwise finetuned) you can choose to still FT the batchnorm layers for i, block in enumerate(remaining_blocks): @@ -257,8 +279,7 @@ def configure_optimizers(self): # "lr": lr * (self.lr_decay**i) # }) - - logging.info('param groups: {}'.format(len(params))) + logging.info(f"param groups: {len(params)}") # because it iterates through the generators, THIS BREAKS TRAINING so only uncomment to debug params # for param_group_n, param_group in enumerate(params): @@ -269,11 +290,17 @@ def configure_optimizers(self): # exit() # Initialize AdamW optimizer - opt = torch.optim.AdamW(params, weight_decay=self.weight_decay) # lr included in params dict - logging.info('Optimizer ready, configuring scheduler') + opt = torch.optim.AdamW( + params, weight_decay=self.weight_decay + ) # lr included in params dict + logging.info("Optimizer ready, configuring scheduler") if self.cosine_schedule: - logging.info('Using lightly cosine schedule, warmup for {} epochs, max for {} epochs'.format(self.warmup_epochs, self.max_cosine_epochs)) + logging.info( + "Using lightly cosine schedule, warmup for {} epochs, max for {} epochs".format( + self.warmup_epochs, self.max_cosine_epochs + ) + ) # from lightly.utils.scheduler import CosineWarmupScheduler #copied from here to avoid dependency # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers # Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config. @@ -295,15 +322,14 @@ def configure_optimizers(self): return { "optimizer": opt, "lr_scheduler": { - 'scheduler': lr_scheduler, - 'interval': 'epoch', - 'frequency': 1 - } + "scheduler": lr_scheduler, + "interval": "epoch", + "frequency": 1, + }, } else: - logging.info('Learning rate scheduler not used') + logging.info("Learning rate scheduler not used") return opt - def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.encoder(x) @@ -316,14 +342,15 @@ def make_step(self, batch): return self.step_to_dict(y, y_pred, loss) def run_step_through_model(self, batch): - # part of training/val/test for all subclasses + # part of training/val/test for all subclasses x, y = batch y_pred = self.forward(x) loss = self.loss(y_pred, y) # must be subclasses and specified + loss.float() return y, y_pred, loss def step_to_dict(self, y, y_pred, loss): - return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y} + return {"loss": loss.mean(), "predictions": y_pred, "labels": y} def training_step(self, batch, batch_idx, dataloader_idx=0): return self.make_step(batch) @@ -333,7 +360,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): def test_step(self, batch, batch_idx, dataloader_idx=0): return self.make_step(batch) - + def predict_step(self, batch, batch_idx) -> Any: # I can't work out how to get webdataset to return a single item im, not a tuple (im,). 
# this is fine for training but annoying for predict @@ -347,52 +374,52 @@ def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0): # arg is shown for val/test equivalents # currently does nothing in Zoobot so inconsequential # https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#on-train-batch-end - self.train_loss_metric(outputs['loss']) + self.train_loss_metric(outputs["loss"]) self.log( - "finetuning/train_loss", - self.train_loss_metric, - prog_bar=self.prog_bar, + "finetuning/train_loss", + self.train_loss_metric, + prog_bar=self.prog_bar, on_step=False, - on_epoch=True + on_epoch=True, ) def on_validation_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx=0): - self.val_loss_metric(outputs['loss']) + self.val_loss_metric(outputs["loss"]) self.log( - "finetuning/val_loss", - self.val_loss_metric, - prog_bar=self.prog_bar, + "finetuning/val_loss", + self.val_loss_metric, + prog_bar=self.prog_bar, on_step=False, - on_epoch=True + on_epoch=True, ) # unique to val batch end if self.visualize_images: - self.upload_images_to_wandb(outputs, batch, batch_idx) + self.upload_images_to_wandb(outputs, batch, batch_idx) def on_test_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx=0): - self.test_loss_metric(outputs['loss']) + self.test_loss_metric(outputs["loss"]) self.log( - "finetuning/test_loss", - self.test_loss_metric, - prog_bar=self.prog_bar, + "finetuning/test_loss", + self.test_loss_metric, + prog_bar=self.prog_bar, on_step=False, - on_epoch=True + on_epoch=True, ) -# lighting v2. removed validation_epoch_end(self, outputs) -# now only has *on_*validation_epoch_end(self) -# replacing by using explicit torchmetric for loss -# https://github.com/Lightning-AI/lightning/releases/tag/2.0.0 + # lighting v2. removed validation_epoch_end(self, outputs) + # now only has *on_*validation_epoch_end(self) + # replacing by using explicit torchmetric for loss + # https://github.com/Lightning-AI/lightning/releases/tag/2.0.0 def upload_images_to_wandb(self, outputs, batch, batch_idx): - raise NotImplementedError('Must be subclassed') - + raise NotImplementedError("Must be subclassed") + @classmethod def load_from_name(cls, name: str, **kwargs): downloaded_loc = download_from_name(cls.__name__, name) - return cls.load_from_checkpoint(downloaded_loc, **kwargs) # trained on GPU, may need map_location='cpu' if you get a device error - - + return cls.load_from_checkpoint( + downloaded_loc, **kwargs + ) # trained on GPU, may need map_location='cpu' if you get a device error @@ -411,80 +438,87 @@ class FinetuneableZoobotClassifier(FinetuneableZoobotAbstract): num_classes (int): num. of target classes (e.g. 2 for binary classification). label_smoothing (float, optional): See torch cross_entropy_loss docs. Defaults to 0. class_weights (arraylike, optional): See torch cross_entropy_loss docs. Defaults to None. - + run_linear_sanity_check (bool, optional): Before fitting, use sklearn to fit a linear model. Defaults to False. 
+ """ def __init__( - self, - num_classes: int, - label_smoothing=0., - class_weights=None, - **super_kwargs) -> None: + self, num_classes: int, label_smoothing=0.0, class_weights=None, run_linear_sanity_check=False, **super_kwargs + ) -> None: super().__init__(**super_kwargs) - logging.info('Using classification head and cross-entropy loss') + logging.info("Using classification head and cross-entropy loss") self.head = LinearHead( input_dim=self.encoder_dim, output_dim=num_classes, - dropout_prob=self.dropout_prob + dropout_prob=self.dropout_prob, ) self.label_smoothing = label_smoothing - self.loss = partial(cross_entropy_loss, - weight=class_weights, - label_smoothing=self.label_smoothing) - logging.info(f'num_classes: {num_classes}') + self.loss = partial( + cross_entropy_loss, + weight=class_weights, + label_smoothing=self.label_smoothing, + ) + logging.info(f"num_classes: {num_classes}") if num_classes == 2: - logging.info('Using binary classification') - task = 'binary' + logging.info("Using binary classification") + task = "binary" else: - logging.info('Using multi-class classification') - task = 'multiclass' + logging.info("Using multi-class classification") + task = "multiclass" self.train_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes) self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes) self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes) - + + self.run_linear_sanity_check = run_linear_sanity_check + def step_to_dict(self, y, y_pred, loss): - y_class_preds = torch.argmax(y_pred, axis=1) # type: ignore - return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y, 'class_predictions': y_class_preds} + y_class_preds = torch.argmax(y_pred, axis=1) # type: ignore + return { + "loss": loss.mean(), + "predictions": y_pred, + "labels": y, + "class_predictions": y_class_preds, + } + def on_train_batch_end(self, step_output, *args): super().on_train_batch_end(step_output, *args) - self.train_acc(step_output['class_predictions'], step_output['labels']) + self.train_acc(step_output["class_predictions"], step_output["labels"]) self.log( - 'finetuning/train_acc', + "finetuning/train_acc", self.train_acc, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) - + def on_validation_batch_end(self, step_output, *args): super().on_validation_batch_end(step_output, *args) - self.val_acc(step_output['class_predictions'], step_output['labels']) + self.val_acc(step_output["class_predictions"], step_output["labels"]) self.log( - 'finetuning/val_acc', + "finetuning/val_acc", self.val_acc, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) def on_test_batch_end(self, step_output, *args) -> None: super().on_test_batch_end(step_output, *args) - self.test_acc(step_output['class_predictions'], step_output['labels']) + self.test_acc(step_output["class_predictions"], step_output["labels"]) self.log( "finetuning/test_acc", self.test_acc, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) - def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx): # see Abstract version if isinstance(x, list) and len(x) == 1: @@ -493,26 +527,56 @@ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx): # then applies softmax return F.softmax(x, dim=1) - def upload_images_to_wandb(self, outputs, batch, batch_idx): - # self.logger is set by pl.Trainer(logger=) argument + # self.logger is set by pl.Trainer(logger=) 
argument if (self.logger is not None) and (batch_idx == 0): x, y = batch - y_pred_softmax = F.softmax(outputs['predictions'], dim=1) + y_pred_softmax = F.softmax(outputs["predictions"], dim=1) n_images = 5 images = [img for img in x[:n_images]] - captions = [f'Ground Truth: {y_i} \nPrediction: {y_p_i}' for y_i, y_p_i in zip( - y[:n_images], y_pred_softmax[:n_images])] - self.logger.log_image( # type: ignore - key='val_images', - images=images, - caption=captions) - - + captions = [ + f"Ground Truth: {y_i} \nPrediction: {y_p_i}" + for y_i, y_p_i in zip(y[:n_images], y_pred_softmax[:n_images]) + ] + self.logger.log_image(key="val_images", images=images, caption=captions) # type: ignore + + # Sanity check embeddings with linear evaluation first + def on_train_start(self) -> None: + if self.run_linear_sanity_check: # default False + self.linear_sanity_check() + + def linear_sanity_check(self): + # only implemented on Zoobot...Classifier as assumes accuracy + with torch.no_grad(): + embeddings, labels = {"train": [], "val": []}, {"train": [], "val": []} + + # Get validation set embeddings + for x, y in self.trainer.datamodule.val_dataloader(): + embeddings["val"] += self.encoder(x.to(self.device)).cpu() + labels["val"] += y + + # Get train set embeddings + for x, y in self.trainer.datamodule.train_dataloader(): + embeddings["train"] += self.encoder(x.to(self.device)).cpu() + labels["train"] += y + + # this is linear *train* acc but that's okay, simply test of features + model = linear_model.LogisticRegression(penalty=None, max_iter=200) + model.fit(embeddings["train"], labels["train"]) + + self.log( + "finetuning/linear_eval/val", + accuracy_score(labels["val"], model.predict(embeddings["val"])), + ) + self.log( + "finetuning/linear_eval/train", + accuracy_score(labels["train"], model.predict(embeddings["train"])), + ) + # doesn't need to be torchmetric, only happens in one go? but distributed class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract): """ - Pretrained Zoobot model intended for finetuning on a regression problem. + Pretrained Zoobot model intended for finetuning on a regression problem. Any args not listed below are passed to :class:``FinetuneableZoobotAbstract`` (for example, `learning_rate`). These are shared between classifier, regressor, and tree models. @@ -525,83 +589,78 @@ class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract): Args: loss (str, optional): Loss function to use. Must be one of 'mse', 'mae'. Defaults to 'mse'. unit_interval (bool, optional): If True, use sigmoid activation for the final layer, ensuring predictions between 0 and 1. Defaults to False. 
- + """ - def __init__( - self, - loss:str='mse', - unit_interval:bool=False, - **super_kwargs) -> None: + def __init__(self, loss: str = "mse", unit_interval: bool = False, **super_kwargs) -> None: super().__init__(**super_kwargs) self.unit_interval = unit_interval if self.unit_interval: - logging.info('unit_interval=True, using sigmoid activation for finetunng head') + logging.info("unit_interval=True, using sigmoid activation for finetunng head") head_activation = torch.nn.functional.sigmoid else: head_activation = None - - logging.info('Using classification head and cross-entropy loss') + + logging.info("Using classification head and cross-entropy loss") self.head = LinearHead( input_dim=self.encoder_dim, output_dim=1, dropout_prob=self.dropout_prob, - activation=head_activation + activation=head_activation, ) - if loss in ['mse', 'mean_squared_error']: + if loss in ["mse", "mean_squared_error"]: self.loss = mse_loss - elif loss in ['mae', 'mean_absolute_error', 'l1', 'l1_loss']: + elif loss in ["mae", "mean_absolute_error", "l1", "l1_loss"]: self.loss = l1_loss else: - raise ValueError(f'Loss {loss} not recognised. Must be one of mse, mae') + raise ValueError(f"Loss {loss} not recognised. Must be one of mse, mae") # rmse metrics. loss is mse already. self.train_rmse = tm.MeanSquaredError(squared=False) self.val_rmse = tm.MeanSquaredError(squared=False) self.test_rmse = tm.MeanSquaredError(squared=False) - + def step_to_dict(self, y, y_pred, loss): - return {'loss': loss.mean(), 'predictions': y_pred, 'labels': y} + return {"loss": loss.mean(), "predictions": y_pred, "labels": y} def on_train_batch_end(self, step_output, *args): super().on_train_batch_end(step_output, *args) - self.train_rmse(step_output['predictions'], step_output['labels']) + self.train_rmse(step_output["predictions"], step_output["labels"]) self.log( - 'finetuning/train_rmse', + "finetuning/train_rmse", self.train_rmse, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) - + def on_validation_batch_end(self, step_output, *args): super().on_validation_batch_end(step_output, *args) - self.val_rmse(step_output['predictions'], step_output['labels']) + self.val_rmse(step_output["predictions"], step_output["labels"]) self.log( - 'finetuning/val_rmse', + "finetuning/val_rmse", self.val_rmse, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) def on_test_batch_end(self, step_output, *args) -> None: super().on_test_batch_end(step_output, *args) - self.test_rmse(step_output['predictions'], step_output['labels']) + self.test_rmse(step_output["predictions"], step_output["labels"]) self.log( "finetuning/test_rmse", self.test_rmse, on_step=False, on_epoch=True, - prog_bar=self.prog_bar + prog_bar=self.prog_bar, ) - def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx): # see Abstract version if isinstance(x, list) and len(x) == 1: @@ -611,9 +670,9 @@ def predict_step(self, x: Union[list[torch.Tensor], torch.Tensor], batch_idx): class FinetuneableZoobotTree(FinetuneableZoobotAbstract): """ - Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. + Pretrained Zoobot model intended for finetuning on a decision tree (i.e. GZ-like) problem. Uses Dirichlet-Multinomial loss introduced in GZ DECaLS. 
- Briefly: predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, + Briefly: predicts a Dirichlet distribution for the probability of a typical volunteer giving each answer, and uses the Dirichlet-Multinomial loss to compare the predicted distribution of votes (given k volunteers were asked) to the true distribution. Does not produce accuracy or MSE metrics, as these are not relevant for this task. Loss logging only. @@ -631,15 +690,11 @@ class FinetuneableZoobotTree(FinetuneableZoobotAbstract): schema (schemas.Schema): description of the layout of the decision tree. See :class:`zoobot.shared.schemas.Schema`. """ - def __init__( - self, - schema: schemas.Schema, - **super_kwargs - ): + def __init__(self, schema: schemas.Schema, **super_kwargs): super().__init__(**super_kwargs) - logging.info('Using dropout+dirichlet head and dirichlet (count) loss') + logging.info("Using dropout+dirichlet head and dirichlet (count) loss") self.schema = schema self.output_dim = len(self.schema.label_cols) @@ -648,16 +703,17 @@ def __init__( encoder_dim=self.encoder_dim, output_dim=self.output_dim, test_time_dropout=False, - dropout_rate=self.dropout_prob + dropout_rate=self.dropout_prob, ) - + self.loss = define_model.get_dirichlet_loss_func(self.schema.question_index_groups) def upload_images_to_wandb(self, outputs, batch, batch_idx): - raise NotImplementedError + raise NotImplementedError # other functions are simply inherited from FinetunedZoobotAbstract + class LinearHead(torch.nn.Module): def __init__(self, input_dim: int, output_dim: int, dropout_prob=0.5, activation=None): """ @@ -691,7 +747,7 @@ def forward(self, x): Returns: torch.Tensor: result (see docstring of LinearHead) """ - # + # x = self.dropout(x) x = self.linear(x) if self.activation is not None: @@ -702,8 +758,9 @@ def forward(self, x): return x - -def cross_entropy_loss(y_pred: torch.Tensor, y: torch.Tensor, label_smoothing: float=0., weight=None): +def cross_entropy_loss( + y_pred: torch.Tensor, y: torch.Tensor, label_smoothing: float = 0.0, weight=None +): """ Calculate cross-entropy loss with optional label smoothing and class weights. No aggregation applied. Trivial wrapper of torch.nn.functional.cross_entropy with reduction='none'. @@ -717,7 +774,13 @@ def cross_entropy_loss(y_pred: torch.Tensor, y: torch.Tensor, label_smoothing: f Returns: torch.Tensor: unreduced cross-entropy loss """ - return F.cross_entropy(y_pred, y.long(), label_smoothing=label_smoothing, weight=weight, reduction='none') + return F.cross_entropy( + y_pred, + y.long(), + label_smoothing=label_smoothing, + weight=weight, + reduction="none", + ) def mse_loss(y_pred, y): @@ -731,7 +794,8 @@ def mse_loss(y_pred, y): Returns: torch.Tensor: See docstring of torch.nn.functional.mse_loss. """ - return F.mse_loss(y_pred, y, reduction='none') + return F.mse_loss(y_pred, y, reduction="none") + def l1_loss(y_pred, y): """ @@ -744,7 +808,7 @@ def l1_loss(y_pred, y): Returns: torch.Tensor: See docstring of torch.nn.functional.l1_loss. """ - return F.l1_loss(y_pred, y, reduction='none') + return F.l1_loss(y_pred, y, reduction="none") def dirichlet_loss(y_pred: torch.Tensor, y: torch.Tensor, question_index_groups): @@ -762,8 +826,9 @@ def dirichlet_loss(y_pred: torch.Tensor, y: torch.Tensor, question_index_groups) torch.Tensor: Dirichlet-Multinomial loss. Scalar, summing across answers and taking a mean across the batch i.e. 
sum(axis=1).mean()) """ # my func uses sklearn convention y, y_pred - return losses.calculate_multiquestion_loss(y, y_pred, question_index_groups).mean()*len(question_index_groups) - + return losses.calculate_multiquestion_loss(y, y_pred, question_index_groups).mean() * len( + question_index_groups + ) def load_pretrained_zoobot(checkpoint_loc: str) -> torch.nn.Module: @@ -778,9 +843,9 @@ def load_pretrained_zoobot(checkpoint_loc: str) -> torch.nn.Module: map_location = None else: # necessary to load gpu-trained model on cpu - map_location = torch.device('cpu') - return define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc, map_location=map_location).encoder # type: ignore - + map_location = torch.device("cpu") + return define_model.ZoobotTree.load_from_checkpoint(checkpoint_loc, map_location=map_location).encoder # type: ignore + def get_trainer( save_dir: str, @@ -788,10 +853,10 @@ def get_trainer( save_top_k=1, max_epochs=100, patience=10, - devices='auto', - accelerator='auto', + devices="auto", + accelerator="auto", logger=None, - **trainer_kwargs + **trainer_kwargs, ) -> pl.Trainer: """ Convenience wrapper to create a PyTorch Lightning Trainer that carries out the finetuning process. @@ -821,29 +886,31 @@ def get_trainer( """ checkpoint_callback = ModelCheckpoint( - monitor='finetuning/val_loss', + monitor="finetuning/val_loss", every_n_epochs=1, save_on_train_epoch_end=True, auto_insert_metric_name=False, verbose=True, - dirpath=os.path.join(save_dir, 'checkpoints'), + dirpath=os.path.join(save_dir, "checkpoints"), filename=file_template, save_weights_only=True, - save_top_k=save_top_k + save_top_k=save_top_k, ) early_stopping_callback = EarlyStopping( - monitor='finetuning/val_loss', - mode='min', - patience=patience + monitor="finetuning/val_loss", mode="min", patience=patience ) - learning_rate_monitor_callback = LearningRateMonitor(logging_interval='epoch') + learning_rate_monitor_callback = LearningRateMonitor(logging_interval="epoch") # Initialise pytorch lightning trainer trainer = pl.Trainer( logger=logger, - callbacks=[checkpoint_callback, early_stopping_callback, learning_rate_monitor_callback], + callbacks=[ + checkpoint_callback, + early_stopping_callback, + learning_rate_monitor_callback, + ], max_epochs=max_epochs, accelerator=accelerator, devices=devices, @@ -871,13 +938,10 @@ def download_from_name(class_name: str, hub_name: str): """ from huggingface_hub import hf_hub_download - if hub_name.startswith('hf_hub:'): - logging.info('Passed name with hf_hub: prefix, dropping prefix') - repo_id = hub_name.split('hf_hub:')[1] + if hub_name.startswith("hf_hub:"): + logging.info("Passed name with hf_hub: prefix, dropping prefix") + repo_id = hub_name.split("hf_hub:")[1] else: repo_id = hub_name - downloaded_loc = hf_hub_download( - repo_id=repo_id, - filename=f"{class_name}.ckpt" - ) + downloaded_loc = hf_hub_download(repo_id=repo_id, filename=f"{class_name}.ckpt") return downloaded_loc diff --git a/zoobot/pytorch/training/train_with_pytorch_lightning.py b/zoobot/pytorch/training/train_with_pytorch_lightning.py index 2c9e7524..4cf47b04 100644 --- a/zoobot/pytorch/training/train_with_pytorch_lightning.py +++ b/zoobot/pytorch/training/train_with_pytorch_lightning.py @@ -304,26 +304,11 @@ def train_default_zoobot_from_scratch( lightning_model = TorchSyncBatchNorm().apply(lightning_model) - extra_callbacks = extra_callbacks if extra_callbacks else [] - - monitor_metric = 'validation/supervised_loss' # used later for checkpoint_callback.best_model_path - checkpoint_callback 
= ModelCheckpoint( - dirpath=os.path.join(save_dir, 'checkpoints'), - monitor=monitor_metric, - save_weights_only=True, - mode='min', - # custom filename for checkpointing due to / in metric - filename=checkpoint_file_template, - # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name - # avoids extra folders from the checkpoint name - auto_insert_metric_name=auto_insert_metric_name, - save_top_k=save_top_k - ) - - early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True) - callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks + checkpoint_callback, callbacks = get_default_callbacks(save_dir, patience, checkpoint_file_template, auto_insert_metric_name, save_top_k) + if extra_callbacks: + callbacks += extra_callbacks trainer = pl.Trainer( num_sanity_val_steps=0, @@ -368,6 +353,27 @@ def train_default_zoobot_from_scratch( return lightning_model, trainer +def get_default_callbacks(save_dir, patience=8, checkpoint_file_template=None, auto_insert_metric_name=True, save_top_k=3): + + monitor_metric = 'validation/supervised_loss' + + checkpoint_callback = ModelCheckpoint( + dirpath=os.path.join(save_dir, 'checkpoints'), + monitor=monitor_metric, + save_weights_only=True, + mode='min', + # custom filename for checkpointing due to / in metric + filename=checkpoint_file_template, + # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name + # avoids extra folders from the checkpoint name + auto_insert_metric_name=auto_insert_metric_name, + save_top_k=save_top_k + ) + + early_stopping_callback = EarlyStopping(monitor=monitor_metric, patience=patience, check_finite=True) + callbacks = [checkpoint_callback, early_stopping_callback] + return checkpoint_callback,callbacks + diff --git a/zoobot/shared/schemas.py b/zoobot/shared/schemas.py index b0123fc3..1d58dcf6 100755 --- a/zoobot/shared/schemas.py +++ b/zoobot/shared/schemas.py @@ -296,6 +296,7 @@ def answers(self): # so don't log anything during Schema.__init__! gz_evo_v1_schema = Schema(label_metadata.gz_evo_v1_pairs, label_metadata.gz_evo_v1_dependencies) +gz_evo_v1_public_schema = Schema(label_metadata.gz_evo_v1_public_pairs, label_metadata.gz_evo_v1_public_dependencies) gz_ukidss_schema = Schema(label_metadata.ukidss_ortho_pairs, label_metadata.ukidss_ortho_dependencies) gz_jwst_schema = Schema(label_metadata.jwst_ortho_pairs, label_metadata.jwst_ortho_dependencies)
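The newly factored-out get_default_callbacks returns the checkpoint callback separately from the full callback list, so callers can still read checkpoint_callback.best_model_path after fitting. A hedged usage sketch (the save directory, model, and datamodule are placeholders):

import pytorch_lightning as pl

from zoobot.pytorch.training.train_with_pytorch_lightning import get_default_callbacks

save_dir = "results/my_run"  # placeholder
checkpoint_callback, callbacks = get_default_callbacks(
    save_dir,
    patience=8,     # epochs without validation/supervised_loss improvement before early stopping
    save_top_k=3,   # keep the three best checkpoints
)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=100, callbacks=callbacks)
# trainer.fit(lightning_model, datamodule)  # model and datamodule defined elsewhere
# print(checkpoint_callback.best_model_path)  # best checkpoint on validation/supervised_loss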