This directory contains all the code that defines models and how to train them. The file model_base.py defines a base class that all models inherit from. The DeepDeinterlacing model as described in the paper is defined in the file deep_deinterlacing.py.
During training the model saves checkpoints in its working directory. The model can be initialized with the trained parameters later using the files in the checkpoint directory. To load the model parameters:
- Change the state_dict_path property in the model config file to the path of the corresponding model_state_dict.pt file.
- Initialize the model with the _initialize_model_from_config method, passing the path to the updated config file.
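The first step can be scripted. The sketch below is only an illustration: it assumes the config file is JSON with a top-level state_dict_path key, which may not match the format this repository actually uses, and set_state_dict_path is a hypothetical helper, not part of the codebase.

```python
import json
import tempfile
from pathlib import Path


def set_state_dict_path(config_path, state_dict_path):
    """Hypothetical helper: rewrite state_dict_path in a JSON config file in place."""
    config_path = Path(config_path)
    config = json.loads(config_path.read_text())
    # Assumption: state_dict_path is a top-level key of the config.
    config["state_dict_path"] = str(state_dict_path)
    config_path.write_text(json.dumps(config, indent=2))
    return config


# Demonstrate on a throwaway config so the sketch runs without the repository.
with tempfile.TemporaryDirectory() as tmp:
    config_file = Path(tmp) / "model_config.json"
    config_file.write_text(
        json.dumps({"model_type": "my_model", "state_dict_path": None})
    )
    checkpoint_file = Path(tmp) / "checkpoints" / "model_state_dict.pt"

    updated = set_state_dict_path(config_file, checkpoint_file)
    reloaded = json.loads(config_file.read_text())
```

The updated config file is then the one whose path is passed to _initialize_model_from_config in the second step.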
To define a new model, create a new file and define a class that inherits from ModelBase. In that file, make sure to:
- Define a configuration class that inherits from BaseConfig that defines the properties specific to the new model.
- Define a configuration class that inherits from BaseConfig that defines the properties specific to the training procedure of the new model.
- Decorate the class of the new model with the register_model decorator as defined in model_base.py. The register_model decorator should be provided with a tag that is used to select the model type in config files. The other arguments link the model to the model configuration and training configuration classes defined in the first two steps.
- Define the model architecture by overriding the _initialize_architecture method.
  Warning: Do not override the __init__ method of the model without calling the parent constructor. Define all variables and model structure in the _initialize_architecture method.
  Warning: Do not forget to aggregate submodules in a torch.nn.ModuleList to ensure they are registered.
- Override the forward method, which defines how data passes through the model.
- Override the _forward_and_compute_loss method, which performs a forward pass and computes the training loss.
- Optional: Override the _create_callbacks method to add model-specific callbacks. For more on defining callbacks, please refer to the utils README. The method should look something like this:
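    def _create_callbacks(self):
        # Start from the callbacks defined by the parent class.
        callbacks = super()._create_callbacks()
        # Append the model-specific callback (MyNewCallback is a placeholder).
        callbacks.append(MyNewCallback(...))
        return callbacks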
- Import the model file in src/__init__.py.
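The steps above can be sketched end to end. The example below is a minimal, self-contained illustration of the registration pattern, not the code in model_base.py: BaseConfig, ModelBase, register_model, and MODEL_REGISTRY are simplified stand-ins, and the decorator's argument order, as well as the base constructor calling _initialize_architecture, are assumptions made for the sketch. The real ModelBase presumably builds on torch.nn.Module, which is left out here so the example runs without PyTorch.

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    """Stand-in for the repository's BaseConfig (assumption)."""


# Maps a tag to (model class, model config class, training config class).
MODEL_REGISTRY = {}


def register_model(tag, model_config_class, training_config_class):
    """Stand-in decorator: record the decorated model class under `tag`."""

    def decorator(model_class):
        if tag in MODEL_REGISTRY:
            raise ValueError(f"model tag {tag!r} is already registered")
        MODEL_REGISTRY[tag] = (model_class, model_config_class, training_config_class)
        return model_class

    return decorator


class ModelBase:
    """Stand-in base class (assumption: the constructor builds the architecture)."""

    def __init__(self, model_config, training_config):
        self.model_config = model_config
        self.training_config = training_config
        # Subclasses define their structure here instead of in __init__.
        self._initialize_architecture()

    def _initialize_architecture(self):
        raise NotImplementedError


@dataclass
class MyModelConfig(BaseConfig):
    hidden_size: int = 16  # hypothetical model-specific property


@dataclass
class MyTrainingConfig(BaseConfig):
    learning_rate: float = 1e-3  # hypothetical training-specific property


@register_model("my_model", MyModelConfig, MyTrainingConfig)
class MyModel(ModelBase):
    def _initialize_architecture(self):
        # Placeholder for real layers; a torch model would build modules here.
        self.layer_sizes = [self.model_config.hidden_size] * 2

    def forward(self, x):
        return x


# Select the model by its tag, as a config file would.
model_class, model_config_class, training_config_class = MODEL_REGISTRY["my_model"]
model = model_class(model_config_class(), training_config_class())
```

Because registration happens when the decorated class is defined, a pattern like this only works if the module containing the class is imported somewhere, which is one plausible reason for the final step of importing the model file in src/__init__.py.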