diff --git a/zoobot/pytorch/training/finetune.py b/zoobot/pytorch/training/finetune.py
index 0a846cb9..144ce79b 100644
--- a/zoobot/pytorch/training/finetune.py
+++ b/zoobot/pytorch/training/finetune.py
@@ -421,15 +421,24 @@ class FinetuneableZoobotRegressor(FinetuneableZoobotAbstract):
 
     def __init__(
             self,
+            unit_interval=False,
             **super_kwargs) -> None:
 
         super().__init__(**super_kwargs)
 
-        logging.info('Using classification head and cross-entropy loss')
+        self.unit_interval = unit_interval
+        if self.unit_interval:
+            logging.info('unit_interval=True, using sigmoid activation for finetuning head')
+            head_activation = torch.nn.functional.sigmoid
+        else:
+            head_activation = None
+
+        logging.info('Using regression head and mean squared error loss')
         self.head = LinearHead(
             input_dim=self.encoder_dim,
             output_dim=1,
-            dropout_prob=self.dropout_prob
+            dropout_prob=self.dropout_prob,
+            activation=head_activation
         )
         self.loss = mse_loss
         # rmse metrics. loss is mse already.
@@ -540,17 +549,21 @@ def upload_images_to_wandb(self, outputs, batch, batch_idx):
 
 # https://github.com/inigoval/byol/blob/1da1bba7dc5cabe2b47956f9d7c6277decd16cc7/byol_main/networks/models.py#L29
 class LinearHead(torch.nn.Module):
-    def __init__(self, input_dim, output_dim, dropout_prob=0.5):
+    def __init__(self, input_dim, output_dim, dropout_prob=0.5, activation=None):
         # input dim is representation dim, output_dim is num classes
         super(LinearHead, self).__init__()
         self.output_dim = output_dim
         self.dropout = torch.nn.Dropout(p=dropout_prob)
         self.linear = torch.nn.Linear(input_dim, output_dim)
+        self.activation = activation  # always assigned, so forward() can safely check it
 
     def forward(self, x):
-        # returns logits, as recommended for CrossEntropy loss
+        # returns logits by default (as recommended for CrossEntropy loss),
+        # or activation(logits) if an activation callable was provided
         x = self.dropout(x)
         x = self.linear(x)
+        if self.activation is not None:
+            x = self.activation(x)
         if self.output_dim == 1:
             return x.squeeze()
         else:
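
For review, here is a minimal, self-contained sketch (not part of the PR) that re-implements the patched `LinearHead` locally, so the new `activation` behaviour can be checked without installing zoobot. `input_dim=640` and the batch size are arbitrary stand-ins for `encoder_dim` and real data.

```python
import torch

# Local copy of the patched LinearHead, for illustration only.
class LinearHead(torch.nn.Module):
    def __init__(self, input_dim, output_dim, dropout_prob=0.5, activation=None):
        super().__init__()
        self.output_dim = output_dim
        self.dropout = torch.nn.Dropout(p=dropout_prob)
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.activation = activation  # e.g. torch.sigmoid, or None for raw logits

    def forward(self, x):
        x = self.dropout(x)
        x = self.linear(x)
        if self.activation is not None:
            x = self.activation(x)
        return x.squeeze() if self.output_dim == 1 else x

head = LinearHead(input_dim=640, output_dim=1, activation=torch.sigmoid)
head.eval()  # disable dropout so the checks below are deterministic
with torch.no_grad():
    preds = head(torch.randn(8, 640))
assert preds.shape == (8,)                 # output_dim=1 is squeezed to 1D
assert ((preds > 0) & (preds < 1)).all()   # sigmoid bounds predictions to (0, 1)
```

Passing `activation=torch.sigmoid` here mirrors what `FinetuneableZoobotRegressor(unit_interval=True)` now does: the head's predictions are squashed to (0, 1), which suits regression targets known to lie on the unit interval, such as vote fractions.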