Adding a new model
We allow prediction models to be easily added and integrated into a PyTorch Lightning module. This incorporates advanced logging and debugging capabilities, as well as built-in parallelism. Our interface derives from the BaseModule.
Adding a model consists of three steps:
- Add a model through the existing MLPredictionWrapper or DLPredictionWrapper.
- Add a GIN config file to bind hyperparameters.
- Execute YAIB using a simple command.
An example for an RNN model is provided in the repository folder docs/adding_model. Putting the RNN.gin file in configs/prediction_models and the rnn.py file into icu_benchmarks/models allows you to run the model fully. We detail the particular steps below and describe the specifics for each case.
For standard Scikit-learn-type models (e.g., LGBM), one can simply wrap the model in the MLPredictionWrapper with minimal code overhead. Many ML (and some DL) models can be incorporated this way. See the example below.
```python
@gin.configurable
class RFClassifier(MLWrapper):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.model_args()

    @gin.configurable(module="RFClassifier")
    def model_args(self, *args, **kwargs):
        return RandomForestClassifier(*args, **kwargs)
```
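Other Scikit-learn-compatible estimators follow the same pattern. The sketch below is illustrative only and not part of the repository: the LGBMWrapper name is made up, and the import path for the MLWrapper base class is an assumption that may need adjusting to your checkout.

```python
import gin
from lightgbm import LGBMClassifier

# Assumed import path for the wrapper base class; adjust to your checkout.
from icu_benchmarks.models.wrappers import MLWrapper


@gin.configurable
class LGBMWrapper(MLWrapper):
    """Illustrative sketch: wraps an LGBM classifier exactly like RFClassifier above."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.model_args()

    @gin.configurable(module="LGBMWrapper")
    def model_args(self, *args, **kwargs):
        # Keyword arguments bound via GIN (e.g. n_estimators) are forwarded to LGBM.
        return LGBMClassifier(*args, **kwargs)
```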
It is relatively straightforward to add new PyTorch models to YAIB. We first provide a standard RNN model, which needs no extra components. Then, we show the implementation of the Temporal Fusion Transformer (TFT) model.
DL models are defined by creating a subclass of the DLPredictionWrapper, which inherits the standard methods needed for training DL models. PyTorch Lightning significantly reduces the code overhead.
```python
@gin.configurable
class RNNet(DLPredictionWrapper):
    """Torch standard RNN model"""

    def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs):
        super().__init__(
            input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
        )
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)
        self.logit = nn.Linear(hidden_dim, num_classes)

    def init_hidden(self, x):
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        return h0

    def forward(self, x):
        h0 = self.init_hidden(x)
        out, hn = self.rnn(x, h0)
        pred = self.logit(out)
        return pred
```
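The forward pass above returns one prediction per time step. To sanity-check the shapes outside of the full pipeline, the self-contained sketch below reproduces the same architecture in plain PyTorch (without the wrapper); all sizes are made up for illustration.

```python
import torch
import torch.nn as nn

# Dummy batch: 8 stays, 24 time steps, 53 features (sizes are illustrative only).
x = torch.randn(8, 24, 53)

hidden_dim, layer_dim, num_classes = 64, 2, 2
rnn = nn.RNN(x.shape[2], hidden_dim, layer_dim, batch_first=True)
logit = nn.Linear(hidden_dim, num_classes)

h0 = x.new_zeros(layer_dim, x.size(0), hidden_dim)
out, hn = rnn(x, h0)   # out: (batch, time, hidden_dim)
pred = logit(out)      # pred: (batch, time, num_classes), one prediction per step
print(pred.shape)      # torch.Size([8, 24, 2])
```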
There are two main questions when you want to add a more complex model:
- Do you want to manually define the model or use an existing library? This might require adapting the DLPredictionWrapper.
- Does the model expect the data to be in a certain format? This might require adapting the PredictionDataset.

By adapting, we mean creating a new subclass that inherits most functionality to avoid code duplication, stays future-proof, and follows good coding practices (see the sketch below).
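To make the idea concrete, a minimal skeleton of such subclasses could look as follows. This is only an illustration: the class names are hypothetical, the import paths are assumptions, and the actual methods to override are shown in the TFT walkthrough below.

```python
import gin
from torch import Tensor

# Assumed import paths; adjust to your checkout.
from icu_benchmarks.models.wrappers import DLPredictionWrapper
from icu_benchmarks.data.loader import PredictionDataset


@gin.configurable
class MyModelWrapper(DLPredictionWrapper):
    """Hypothetical wrapper: only the forward pass differs, training and logging are inherited."""

    def forward(self, x: Tensor) -> Tensor:
        ...


@gin.configurable("MyModelDataset")
class MyModelDataset(PredictionDataset):
    """Hypothetical dataset: only the sampling format changes, caching stays inherited."""

    def __getitem__(self, idx: int):
        ...
```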
First, you can add modules to models/layers.py to use them in your model.
```python
class StaticCovariateEncoder(nn.Module):
    """
    Network to produce 4 context vectors to enrich static variables
    Variable Selection Network --> GRNs
    """

    def __init__(self, num_static_vars, hidden, dropout):
        super().__init__()
        self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars)
        self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)])

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        variable_ctx, sparse_weights = self.vsn(x)

        # Context vectors:
        # variable selection context
        # enrichment context
        # state_c context
        # state_h context
        cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns]
        return cs, ce, ch, cc
```
Note that we can create modules out of modules as well.
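For instance, a higher-level block can reuse the StaticCovariateEncoder defined above. The following composite module is purely illustrative (the EnrichedEncoder name and its use of an LSTM are made up, not part of the repository):

```python
import torch.nn as nn


class EnrichedEncoder(nn.Module):
    """Illustrative composite module that nests StaticCovariateEncoder inside a larger block."""

    def __init__(self, num_static_vars, hidden, dropout):
        super().__init__()
        self.static_encoder = StaticCovariateEncoder(num_static_vars, hidden, dropout)
        self.temporal = nn.LSTM(hidden, hidden, batch_first=True)

    def forward(self, static_x, temporal_x):
        cs, ce, ch, cc = self.static_encoder(static_x)
        # Use the state contexts to initialise the LSTM hidden and cell state.
        out, _ = self.temporal(temporal_x, (ch.unsqueeze(0), cc.unsqueeze(0)))
        return out, ce
```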
The next step is to use the building blocks defined in layers.py, or modules from an existing library, to add the model in models/dl_models.py. In this case, we use the PyTorch Forecasting library (https://github.com/jdb78/pytorch-forecasting):
```python
class TFTpytorch(DLPredictionWrapper):
    supported_run_modes = [RunMode.classification, RunMode.regression]

    def __init__(self, dataset, hidden, dropout, n_heads, dropout_att, lr, optimizer, num_classes, *args, **kwargs):
        super().__init__(lr=lr, optimizer=optimizer, *args, **kwargs)
        self.model = TemporalFusionTransformer.from_dataset(dataset=dataset)
        self.logit = nn.Linear(7, num_classes)

    def forward(self, x):
        out = self.model(x)
        pred = self.logit(out["prediction"])
        return pred
```
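A note on the nn.Linear(7, num_classes) head: by default, pytorch-forecasting's TemporalFusionTransformer produces 7 quantile outputs per step (its default QuantileLoss uses 7 quantiles), so out["prediction"] has 7 values in its last dimension; the linear head maps these to class logits. If you configure a different output_size or loss when calling from_dataset, adjust the head accordingly.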
Some models require an adjusted dataloader, for example, to facilitate explainability methods. In this case, changes need to be made to the data/loader.py file to ensure the dataloader returns the data in the correct format. This can be done by creating a class that inherits from PredictionDataset and overriding the __getitem__ method.
@gin.configurable("PredictionDatasetTFT")
class PredictionDatasetTFT(PredictionDataset):
def __init__(self, *args, ram_cache: bool = True, **kwargs):
super().__init__(*args, ram_cache=True, **kwargs)
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Function to sample from the data split of choice. Used for TFT.
The data needs to be given to the model in the following order
[static categorical, static continuous,known categorical,known continuous, observed categorical, observed continuous,target,id]
Then, you must check models/wrapper.py, particularly the step_fn method, to ensure the data is correctly transferred to the device.
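"Correctly transferred" means that every tensor in the (now larger) batch tuple ends up on self.device before the forward pass. The fragment below is purely illustrative: the actual step_fn in models/wrapper.py may have a different signature and additional responsibilities (loss computation, metric logging), so treat this as a sketch of the device handling only.

```python
from torch import Tensor


# Illustrative only: signature and surrounding logic of the real step_fn may differ.
def step_fn(self, batch, step_prefix=""):
    # A TFT-style batch is a tuple of several tensors (static/known/observed
    # features, target, id); each element must be moved to the model's device.
    batch = tuple(t.to(self.device) if isinstance(t, Tensor) else t for t in batch)
    ...
```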
To define hyperparameters for each model in a standardized manner, we use GIN-config. We need to specify a GIN file that binds the parameters used to train and optimize this model to a choice of hyperparameters. Note that we can use modifiers for the optimizer (e.g., the Adam optimizer) and ranges, which we specify in round brackets "()". Square brackets, "[]", result in a random choice where the variable is uniformly sampled.
```gin
# Hyperparameters for TFT model.

# Common settings for DL models
include "configs/prediction_models/common/DLCommon.gin"

# Optimizer params
train_common.model = @TFT
optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = (1e-5, 3e-4)

# Encoder params
model/hyperparameter.class_to_tune = @TFT
model/hyperparameter.encoder_length = 24
model/hyperparameter.hidden = 256
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.dropout = (0.0, 0.4)
model/hyperparameter.dropout_att = (0.0, 0.4)
model/hyperparameter.n_heads = 4
model/hyperparameter.example_length = 25
```
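For reference, this is how GIN bindings reach a model at runtime: any constructor argument of a @gin.configurable class can be set from a config file or string. The self-contained sketch below uses a made-up SmallModel class and plain gin-config, without YAIB's additional bracket syntax for tuning ranges.

```python
import gin


@gin.configurable
class SmallModel:
    def __init__(self, hidden=64, dropout=0.1):
        self.hidden = hidden
        self.dropout = dropout


# Bindings that would normally live in a .gin file.
gin.parse_config("""
SmallModel.hidden = 256
SmallModel.dropout = 0.2
""")

model = SmallModel()                  # arguments are filled in from the bindings
print(model.hidden, model.dropout)    # 256 0.2
```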
After these steps, your model should be trainable with the following command:
```
icu-benchmarks train \
    -d demo_data/mortality24/mimic_demo \  # Insert cohort dataset here
    -n mimic_demo \
    -t BinaryClassification \              # Insert task name here
    -tn Mortality24 \
    --log-dir ../yaib_logs/ \
    -m TFT \                               # Insert model here
    -s 2222 \
    -l ../yaib_logs/ \
    --tune
```