diff --git a/UltrasoundSegmentation/configs/UsTracking_train_config.yaml b/UltrasoundSegmentation/configs/UsTracking_train_config.yaml index 3c58dd6..afa4faa 100644 --- a/UltrasoundSegmentation/configs/UsTracking_train_config.yaml +++ b/UltrasoundSegmentation/configs/UsTracking_train_config.yaml @@ -6,8 +6,8 @@ lambda_ce: 0.9 # 0.0 for dice, 1.0 for cross-entropy/focal class_weights: [0.2, 0.8] # must be same length as out_channels image_size: 128 in_channels: &in_c !!int 5 -out_channels: !!int 2 -num_epochs: !!int 1 +out_channels: &out_c !!int 2 +num_epochs: !!int 100 batch_size: !!int 128 learning_rate: !!float 0.001 learning_rate_decay_factor: !!float 0.5 @@ -85,3 +85,5 @@ model: name: DualEncoderUNet use_tracking: true params: + in_channels: *in_c + out_channels: *out_c diff --git a/UltrasoundSegmentation/configs/train_config.yaml b/UltrasoundSegmentation/configs/train_config.yaml index fc034c1..632a4f2 100644 --- a/UltrasoundSegmentation/configs/train_config.yaml +++ b/UltrasoundSegmentation/configs/train_config.yaml @@ -5,8 +5,8 @@ loss_function: DiceCE lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal class_weights: [0.25, 0.75] # must be same length as out_channels image_size: 128 -in_channels: !!int 1 -out_channels: !!int 2 +in_channels: &in_c !!int 1 +out_channels: &out_c !!int 2 num_epochs: !!int 100 batch_size: !!int 64 learning_rate: !!float 0.0001 @@ -47,4 +47,7 @@ model: model_path: /path/to/model.py name: CustomNet # name of the model class use_tracking: false # use tracking data in model input - params: # model-specific parameters + params: + in_channels: *in_c + out_channels: *out_c + # other model-specific parameters here diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index a76a4d7..2615fbd 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -302,8 +302,6 @@ def main(args): # instantiate model model_params = model_cfg.get("params", {}) model = getattr(module, model_cfg["name"])( - in_channels=config["in_channels"], - out_channels=config["out_channels"], **(model_params if model_params else {}) ) use_tracking = model_cfg["use_tracking"]