From fe05f541a95e95431b7ccf81a4e30ea3bbdc455c Mon Sep 17 00:00:00 2001 From: Maciej Szankin Date: Fri, 29 Sep 2023 22:19:30 -0700 Subject: [PATCH] Update OFA Signed-off-by: Maciej Szankin --- .../supernetwork/image_classification/ofa/ofa_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dynast/supernetwork/image_classification/ofa/ofa_interface.py b/dynast/supernetwork/image_classification/ofa/ofa_interface.py index 10052e6..0ecb620 100644 --- a/dynast/supernetwork/image_classification/ofa/ofa_interface.py +++ b/dynast/supernetwork/image_classification/ofa/ofa_interface.py @@ -16,7 +16,7 @@ import csv import uuid from datetime import datetime -from typing import Dict, Tuple +from typing import Dict, Tuple, Union import torch @@ -46,7 +46,7 @@ class OFARunner(Runner): def __init__( self, supernet: str, - dataset_path: str, + dataset_path: Union[None, str] = None, predictors: Dict[str, Predictor] = {}, batch_size: int = 128, eval_batch_size: int = 128, @@ -301,7 +301,6 @@ def eval_subnet(self, x): # Predictor Mode if self.predictor_mode == True: for metric in self.optimization_metrics: - # TODO(macsz) Maybe move [0] to the `estimate_metric`. individual_results[metric] = self.evaluator.estimate_metric( metric, self.manager.onehot_generic(x).reshape(1, -1) )[0]