diff --git a/nas-search/plot-progress/parse_search_output.py b/nas-search/plot-progress/parse_search_output.py index 045f83d..6862f9b 100644 --- a/nas-search/plot-progress/parse_search_output.py +++ b/nas-search/plot-progress/parse_search_output.py @@ -33,6 +33,7 @@ def encode_single_path_nas_arch(inds, hard=False): for layer_cnt in range(20): inds_row = inds[layer_cnt] + print(inds_row) if inds_row == [0.0, 0.0, 0.0]: idx = 4 # skip elif inds_row == [0.0, 0.0, 1.0]: diff --git a/train-final/models.py b/train-final/models.py index a11ff83..4f5e08b 100644 --- a/train-final/models.py +++ b/train-final/models.py @@ -107,7 +107,7 @@ def encode(self, blocks_args): import parse_netarch -def parse_netarch(parse_lambda_dir, depth_multiplier=None): +def parse_netarch_model(parse_lambda_dir, depth_multiplier=None): """Creates the RNAS found model. No need to hard-code model, it parses the output of previous search @@ -128,7 +128,7 @@ def parse_netarch(parse_lambda_dir, depth_multiplier=None): indicator_values = parse_netarch.parse_indicators_single_path_nas(parse_lambda_dir, tf_size_guidance) network = parse_netarch.encode_single_path_nas_arch(indicator_values) parse_netarch.print_net(network) - blocks_args = parse_netarch.mnasnet_encoder(network) + blocks_args = parse_netarch.convnet_encoder(network) parse_netarch.print_encoded_net(blocks_args) decoder = MnasNetDecoder() @@ -146,7 +146,7 @@ def parse_netarch(parse_lambda_dir, depth_multiplier=None): def build_model(images, model_name, training, override_params=None, - parse_output_dir=None): + parse_search_dir=None): """A helper functiion to creates a ConvNet model and returns predicted logits. Args: @@ -165,8 +165,8 @@ def build_model(images, model_name, training, override_params=None, """ assert isinstance(images, tf.Tensor) if model_name == 'single-path': - assert parse_output_dir is not None - blocks_args, global_params = parse_netarch(parse_output_dir) + assert parse_search_dir is not None + blocks_args, global_params = parse_netarch_model(parse_search_dir) else: raise NotImplementedError('model name is not pre-defined: %s' % model_name)