diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index bb30af0..e9aa3db 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -187,7 +187,7 @@ def main(args): train_dataset, batch_size=config["batch_size"], shuffle=config["shuffle"], - num_workers=4, + num_workers=2, generator=g ) val_dataloader = DataLoader( @@ -252,7 +252,7 @@ def main(args): out_channels=config["out_channels"], channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), - num_res_units=2, + num_res_units=config["num_res_units"] if "num_res_units" in config else 2, dropout=dropout_rate ) model = model.to(device=device)