Skip to content

Commit

Permalink
Merge pull request #6 from aertslab/deeptopic_lstm
Browse files Browse the repository at this point in the history
Deeptopic LSTM model
  • Loading branch information
LukasMahieu authored Jul 15, 2024
2 parents 2c481e5 + a093678 commit 3d3dc4a
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ wandb/
_build
_autosummary
node_modules
.DS_Store
._.DS_Store
3 changes: 2 additions & 1 deletion docs/api/tools/zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ Supply these (or your own) to `tl.Crested(...)` to use them in training.
basenji
chrombpnet
deeptopic_cnn
deeptopic_lstm
simple_convnet
```
```
1 change: 1 addition & 0 deletions src/crested/tl/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._basenji import basenji
from ._chrombpnet import chrombpnet
from ._deeptopic_cnn import deeptopic_cnn
from ._deeptopic_lstm import deeptopic_lstm
from ._simple_convnet import simple_convnet
110 changes: 110 additions & 0 deletions src/crested/tl/zoo/_deeptopic_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Deeptopic LSTM model architecture."""

import keras
import pickle

from crested.tl.zoo.utils import get_output


def deeptopic_lstm(
    seq_len: int,
    num_classes: int,
    filters: int = 300,
    first_kernel_size: int = 30,
    max_pool_size: int = 15,
    max_pool_stride: int = 5,
    dense_out: int = 256,
    lstm_out: int = 128,
    first_activation: str = "relu",
    activation: str = "relu",
    lstm_do: float = 0.1,
    dense_do: float = 0.4,
    pre_dense_do: float = 0.2,
    motifs_path: str = None,
) -> keras.Model:
    """
    Construct a DeepTopicLSTM model. Usually used for topic classification.

    The network is: Conv1D -> MaxPooling1D -> Dropout -> TimeDistributed(Dense)
    -> Bidirectional(LSTM) -> Dropout -> Flatten -> Dense -> Dropout
    -> Dense(sigmoid).

    Parameters
    ----------
    seq_len
        Width of the input region.
    num_classes
        Number of classes to predict.
    filters
        Number of filters in the convolutional layer.
    first_kernel_size
        Size of the kernel in the convolutional layer.
    max_pool_size
        Size of the max pooling kernel.
    max_pool_stride
        Stride of the max pooling kernel.
    dense_out
        Number of neurons in the dense layer.
    lstm_out
        Number of units in the lstm layer. Also used as the number of units
        of the time-distributed dense layer that precedes the lstm.
    first_activation
        Activation function for the convolutional layer.
    activation
        Activation function for the time-distributed dense layer and the
        dense layer.
    lstm_do
        Dropout rate for the lstm layer (used for both the input dropout and
        the recurrent dropout).
    dense_do
        Dropout rate after the dense layer.
    pre_dense_do
        Dropout rate applied after the max pooling and after the lstm layer.
    motifs_path
        Path to a pickled mapping of motif name to motif matrix used to
        initialize the convolutional weights. Each matrix must be assignable
        to a ``(motif_length, 4)`` slice of the kernel. If None, the
        convolution keeps its random initialization.

    Returns
    -------
    keras.Model
        A Keras model.

    Raises
    ------
    ValueError
        If the motif file holds more motifs than there are filters, or if a
        motif is longer than ``first_kernel_size``.
    """
    inputs = keras.layers.Input(shape=(seq_len, 4), name="sequence")

    # Keep a handle on the convolution so that motif initialization below can
    # address it directly instead of relying on its position in model.layers.
    first_conv = keras.layers.Convolution1D(
        filters=filters,
        kernel_size=first_kernel_size,
        activation=first_activation,
        padding="valid",
        kernel_initializer="random_uniform",
    )

    hidden_layers = [
        first_conv,
        keras.layers.MaxPooling1D(
            pool_size=max_pool_size,
            strides=max_pool_stride,
            padding="valid",
        ),
        keras.layers.Dropout(pre_dense_do),
        keras.layers.TimeDistributed(
            keras.layers.Dense(lstm_out, activation=activation)
        ),
        keras.layers.Bidirectional(
            keras.layers.LSTM(
                lstm_out,
                dropout=lstm_do,
                recurrent_dropout=lstm_do,
                return_sequences=True,
            )
        ),
        keras.layers.Dropout(pre_dense_do),
        keras.layers.Flatten(),
        keras.layers.Dense(dense_out, activation=activation),
        keras.layers.Dropout(dense_do),
        keras.layers.Dense(num_classes, activation="sigmoid"),
    ]

    outputs = get_output(inputs, hidden_layers)

    model = keras.Model(inputs=inputs, outputs=outputs)

    if motifs_path is not None:
        # SECURITY NOTE: unpickling can execute arbitrary code. Only pass
        # motif files that come from a trusted source.
        with open(motifs_path, "rb") as f:
            motif_dict = pickle.load(f)

        if len(motif_dict) > filters:
            raise ValueError(
                f"Motif file contains {len(motif_dict)} motifs but the "
                f"convolutional layer only has {filters} filters."
            )

        # conv_weights[0] is the kernel, indexed as (position, channel, filter).
        conv_weights = first_conv.get_weights()
        kernel = conv_weights[0]

        for i, (name, motif) in enumerate(motif_dict.items()):
            motif_len = len(motif)
            if motif_len > first_kernel_size:
                raise ValueError(
                    f"Motif '{name}' has length {motif_len}, which exceeds "
                    f"first_kernel_size={first_kernel_size}."
                )
            # Dampen the random initialization of this filter, then place the
            # motif centered within the kernel window.
            kernel[:, :, i] = kernel[:, :, i] * 0.1
            start = (first_kernel_size - motif_len) // 2
            kernel[start:start + motif_len, :, i] = motif

        first_conv.set_weights(conv_weights)

    return model
24 changes: 24 additions & 0 deletions src/crested/tl/zoo/utils/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"dense_block",
"conv_block",
"activate",
"get_output",
"conv_block_bs",
"dilated_residual",
]
Expand Down Expand Up @@ -205,6 +206,29 @@ def activate(
return current


def get_output(input_layer, hidden_layers):
    """
    Chain the hidden layers onto the input layer.

    Parameters
    ----------
    input_layer
        Starting value that is fed to the first hidden layer.
    hidden_layers
        Iterable of callables applied in iteration order; each one receives
        the result of the previous one.

    Returns
    -------
    The result of the last hidden layer, or ``input_layer`` itself when
    ``hidden_layers`` is empty.
    """
    current = input_layer
    for layer in hidden_layers:
        current = layer(current)
    return current


def conv_block_bs(
inputs,
filters: int | None = None,
Expand Down

0 comments on commit 3d3dc4a

Please sign in to comment.