add output activation option to all models
LukasMahieu committed Nov 15, 2024
1 parent 4d19481 commit 9df5881
Showing 3 changed files with 17 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/crested/tl/zoo/_chrombpnet.py
@@ -16,6 +16,7 @@ def chrombpnet(
     num_filters: int = 512,
     filter_size: int = 3,
     activation: str = "relu",
+    output_activation: str = "softplus",
     l2: float = 0.00001,
     dropout: float = 0.1,
     batch_norm: bool = True,
@@ -50,6 +51,8 @@ def chrombpnet(
         Size of the kernel in the dilated convolutional layers.
     activation
         Activation function in the dilated convolutional layers.
+    output_activation
+        Activation function for the output layer.
     l2
         L2 regularization for the dilated convolutional layers.
     dropout
@@ -126,7 +129,10 @@ def chrombpnet(
 
     x = keras.layers.GlobalAveragePooling1D()(x)
     outputs = keras.layers.Dense(
-        units=num_classes, activation="softplus", use_bias=dense_bias, name="dense_out"
+        units=num_classes,
+        activation=output_activation,
+        use_bias=dense_bias,
+        name="dense_out",
     )(x)
 
     model = keras.Model(inputs=inputs, outputs=outputs)
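
The new output_activation parameter defaults to "softplus", the value that was previously hard-coded, so existing callers see no behaviour change. A minimal usage sketch (the leading seq_len and num_classes argument names are assumptions, as they fall outside this diff):

    from crested.tl.zoo import chrombpnet

    # Sketch only: seq_len/num_classes are assumed argument names not shown
    # in this diff. The default output_activation="softplus" reproduces the
    # old hard-coded head; "linear" gives raw, unsquashed outputs instead.
    model = chrombpnet(seq_len=2114, num_classes=2, output_activation="linear")
    model.summary()
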
8 changes: 6 additions & 2 deletions src/crested/tl/zoo/_deeptopic_cnn.py
@@ -14,6 +14,7 @@ def deeptopic_cnn(
     dense_out: int = 1024,
     first_activation: str = "gelu",
     activation: str = "relu",
+    output_activation: str = "sigmoid",
     conv_do: float = 0.15,
     normalization: str = "batch",
     dense_do: float = 0.5,
@@ -43,6 +44,8 @@
         Activation function for the first conv block.
     activation
         Activation function for subsequent blocks.
+    output_activation
+        Activation function for the output layer.
     conv_do
         Dropout rate for the convolutional layers.
     normalization
@@ -135,6 +138,7 @@ def deeptopic_cnn(
         name_prefix="denseblock",
         use_bias=False,
     )
-    logits = keras.layers.Dense(num_classes, activation="linear", use_bias=True)(x)
-    outputs = keras.layers.Activation("sigmoid")(logits)
+    outputs = keras.layers.Dense(
+        num_classes, activation=output_activation, use_bias=True
+    )(x)
     return keras.Model(inputs=inputs, outputs=outputs)
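
This change also collapses the old two-step head, a linear Dense producing logits followed by a separate sigmoid Activation, into a single Dense layer with a configurable activation. A small sketch of why the default is equivalent, using only keras:

    import keras

    # Sketch: with the same weights, Dense(activation="sigmoid") computes
    # exactly sigmoid(Wx + b), i.e. the same function as a linear Dense
    # followed by a separate sigmoid Activation layer, so the default
    # output_activation="sigmoid" preserves the old behaviour.
    inputs = keras.Input(shape=(16,))
    logits = keras.layers.Dense(3, activation="linear", use_bias=True)(inputs)
    old_head = keras.layers.Activation("sigmoid")(logits)            # before
    new_head = keras.layers.Dense(3, activation="sigmoid")(inputs)   # after
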
5 changes: 4 additions & 1 deletion src/crested/tl/zoo/_deeptopic_lstm.py
@@ -18,6 +18,7 @@ def deeptopic_lstm(
     lstm_out: int = 128,
     first_activation: str = "relu",
     activation: str = "relu",
+    output_activation: str = "sigmoid",
     lstm_do: float = 0.1,
     dense_do: float = 0.4,
     pre_dense_do: float = 0.2,
@@ -49,6 +50,8 @@
         Activation function for the first conv block.
     activation
         Activation function for subsequent blocks.
+    output_activation
+        Activation function for the output layer.
     lstm_do
         Dropout rate for the lstm layer.
     dense_do
@@ -91,7 +94,7 @@ def deeptopic_lstm(
         keras.layers.Flatten(),
         keras.layers.Dense(dense_out, activation=activation),
         keras.layers.Dropout(dense_do),
-        keras.layers.Dense(num_classes, activation="sigmoid"),
+        keras.layers.Dense(num_classes, activation=output_activation),
     ]
 
     outputs = get_output(inputs, hidden_layers)
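
As in the other two models, the default ("sigmoid") matches the previous hard-coded output layer, and any Keras activation name can be passed instead. A hypothetical call (argument names other than output_activation are assumptions):

    from crested.tl.zoo import deeptopic_lstm

    # Sketch: swap the multi-label sigmoid head for a softmax head, e.g.
    # for mutually exclusive classes. seq_len/num_classes are assumed
    # argument names; only output_activation appears in this diff.
    model = deeptopic_lstm(seq_len=500, num_classes=10, output_activation="softmax")
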
