Skip to content

Commit

Permalink
Moved all custom model parameters from train.py to train_config.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed Nov 27, 2024
1 parent 2fc7107 commit 23dcf18
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
6 changes: 4 additions & 2 deletions UltrasoundSegmentation/configs/UsTracking_train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ lambda_ce: 0.9 # 0.0 for dice, 1.0 for cross-entropy/focal
class_weights: [0.2, 0.8] # must be same length as out_channels
image_size: 128
in_channels: &in_c !!int 5
out_channels: !!int 2
num_epochs: !!int 1
out_channels: &out_c !!int 2
num_epochs: !!int 100
batch_size: !!int 128
learning_rate: !!float 0.001
learning_rate_decay_factor: !!float 0.5
Expand Down Expand Up @@ -85,3 +85,5 @@ model:
name: DualEncoderUNet
use_tracking: true
params:
in_channels: *in_c
out_channels: *out_c
9 changes: 6 additions & 3 deletions UltrasoundSegmentation/configs/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ loss_function: DiceCE
lambda_ce: 0.5 # 0.0 for dice, 1.0 for cross-entropy/focal
class_weights: [0.25, 0.75] # must be same length as out_channels
image_size: 128
in_channels: !!int 1
out_channels: !!int 2
in_channels: &in_c !!int 1
out_channels: &out_c !!int 2
num_epochs: !!int 100
batch_size: !!int 64
learning_rate: !!float 0.0001
Expand Down Expand Up @@ -47,4 +47,7 @@ model:
model_path: /path/to/model.py
name: CustomNet # name of the model class
use_tracking: false # use tracking data in model input
params: # model-specific parameters
params:
in_channels: *in_c
out_channels: *out_c
# other model-specific parameters here
2 changes: 0 additions & 2 deletions UltrasoundSegmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ def main(args):
# instantiate model
model_params = model_cfg.get("params", {})
model = getattr(module, model_cfg["name"])(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
**(model_params if model_params else {})
)
use_tracking = model_cfg["use_tracking"]
Expand Down

0 comments on commit 23dcf18

Please sign in to comment.