Formatting
inigoval committed May 23, 2024
1 parent a58d72b commit bb58215
Showing 1 changed file with 17 additions and 45 deletions.
zoobot/pytorch/training/finetune.py (62 changes: 17 additions & 45 deletions)
@@ -113,9 +113,7 @@ def __init__(
         # this raises a warning that encoder is already a Module hence saved in checkpoint hence no need to save as hparam
         # true - except we need it to instantiate this class, so it's really handy to have saved as well
         # therefore ignore the warning
-        self.save_hyperparameters(
-            ignore=["encoder"]
-        )  # never serialise the encoder, way too heavy
+        self.save_hyperparameters(ignore=["encoder"])  # never serialise the encoder, way too heavy
         # if you need the encoder to recreate, pass when loading checkpoint e.g.
         # FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)

@@ -125,9 +123,7 @@ def __init__(
             self.encoder_dim = self.encoder.num_features

         elif zoobot_checkpoint_loc is not None:
-            assert (
-                encoder is None
-            ), "Cannot pass both checkpoint to load and encoder to use"
+            assert encoder is None, "Cannot pass both checkpoint to load and encoder to use"
             self.encoder = load_pretrained_zoobot(
                 zoobot_checkpoint_loc
             )  # extracts the timm encoder
@@ -137,9 +133,7 @@ def __init__(
             assert (
                 zoobot_checkpoint_loc is None
             ), "Cannot pass both checkpoint to load and encoder to use"
-            assert (
-                encoder is not None
-            ), "Must pass either checkpoint to load or encoder to use"
+            assert encoder is not None, "Must pass either checkpoint to load or encoder to use"
             self.encoder = encoder
             # work out encoder dim 'manually'
             if isinstance(self.encoder, timm.models.VisionTransformer):
@@ -231,20 +225,12 @@ def configure_optimizers(self):
                 self.encoder.layer4,
             ]
         elif isinstance(self.encoder, timm.models.MaxxVit):
-            tuneable_blocks = [self.encoder.stem] + [
-                stage for stage in self.encoder.stages
-            ]
-        elif isinstance(
-            self.encoder, timm.models.ConvNeXt
-        ):  # stem + 4 blocks, for all sizes
+            tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
+        elif isinstance(self.encoder, timm.models.ConvNeXt):  # stem + 4 blocks, for all sizes
             # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py#L264
-            tuneable_blocks = [self.encoder.stem] + [
-                stage for stage in self.encoder.stages
-            ]
+            tuneable_blocks = [self.encoder.stem] + [stage for stage in self.encoder.stages]
         elif isinstance(self.encoder, timm.models.VisionTransformer):
-            tuneable_blocks = [self.encoder.patch_embed] + [
-                block for block in self.encoder.blocks
-            ]
+            tuneable_blocks = [self.encoder.patch_embed] + [block for block in self.encoder.blocks]
         else:
             raise ValueError(
                 f"Encoder architecture not automatically recognised: {type(self.encoder)}"
@@ -389,9 +375,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx=0):
             on_epoch=True,
         )

-    def on_validation_batch_end(
-        self, outputs: dict, batch, batch_idx: int, dataloader_idx=0
-    ):
+    def on_validation_batch_end(self, outputs: dict, batch, batch_idx: int, dataloader_idx=0):
         self.val_loss_metric(outputs["loss"])
         self.log(
             "finetuning/val_loss",
@@ -473,9 +457,7 @@ def __init__(
         else:
             logging.info("Using multi-class classification")
             task = "multiclass"
-        self.train_acc = tm.Accuracy(
-            task=task, average="micro", num_classes=num_classes
-        )
+        self.train_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
         self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
         self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)

@@ -572,9 +554,7 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
             f"Ground Truth: {y_i} \nPrediction: {y_p_i}"
             for y_i, y_p_i in zip(y[:n_images], y_pred_softmax[:n_images])
         ]
-        self.logger.log_image(  # type: ignore
-            key="val_images", images=images, caption=captions
-        )
+        self.logger.log_image(key="val_images", images=images, caption=captions)  # type: ignore


 class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
@@ -595,17 +575,13 @@ class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
     """

-    def __init__(
-        self, loss: str = "mse", unit_interval: bool = False, **super_kwargs
-    ) -> None:
+    def __init__(self, loss: str = "mse", unit_interval: bool = False, **super_kwargs) -> None:

         super().__init__(**super_kwargs)

         self.unit_interval = unit_interval
         if self.unit_interval:
-            logging.info(
-                "unit_interval=True, using sigmoid activation for finetunng head"
-            )
+            logging.info("unit_interval=True, using sigmoid activation for finetunng head")
             head_activation = torch.nn.functional.sigmoid
         else:
             head_activation = None
@@ -713,9 +689,7 @@ def __init__(self, schema: schemas.Schema, **super_kwargs):
             dropout_rate=self.dropout_prob,
         )

-        self.loss = define_model.get_dirichlet_loss_func(
-            self.schema.question_index_groups
-        )
+        self.loss = define_model.get_dirichlet_loss_func(self.schema.question_index_groups)

     def upload_images_to_wandb(self, outputs, batch, batch_idx):
         raise NotImplementedError
@@ -724,9 +698,7 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):


 class LinearHead(torch.nn.Module):
-    def __init__(
-        self, input_dim: int, output_dim: int, dropout_prob=0.5, activation=None
-    ):
+    def __init__(self, input_dim: int, output_dim: int, dropout_prob=0.5, activation=None):
         """
         Small utility class for a linear head with dropout and optional choice of activation.
@@ -837,9 +809,9 @@ def dirichlet_loss(y_pred: torch.Tensor, y: torch.Tensor, question_index_groups)
         torch.Tensor: Dirichlet-Multinomial loss. Scalar, summing across answers and taking a mean across the batch i.e. sum(axis=1).mean())
     """
     # my func uses sklearn convention y, y_pred
-    return losses.calculate_multiquestion_loss(
-        y, y_pred, question_index_groups
-    ).mean() * len(question_index_groups)
+    return losses.calculate_multiquestion_loss(y, y_pred, question_index_groups).mean() * len(
+        question_index_groups
+    )


 def load_pretrained_zoobot(checkpoint_loc: str) -> torch.nn.Module:
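Note: as a usage illustration only (not part of the commit), the snippet below sketches how the reformatted classes are constructed, using only the signatures visible in the hunks above; the checkpoint path is a placeholder.

# Illustrative sketch, assuming only what the diff shows:
# FinetuneableZoobotRegressor(loss="mse", unit_interval=False, **super_kwargs), with the
# abstract base accepting either `encoder` or `zoobot_checkpoint_loc` (but not both).
from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotRegressor(
    loss="mse",
    unit_interval=True,  # per the diff, puts a sigmoid activation on the finetuning head
    zoobot_checkpoint_loc="path/to/zoobot_checkpoint.ckpt",  # placeholder path
)

# Per the comment in the first hunk, the encoder is never serialised, so pass it back in
# when reloading a finetuned checkpoint, e.g.
# FinetuneableZoobotTree.load_from_checkpoint(loc, encoder=encoder)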
