Skip to content

Commit

Permalink
fix train-final naming issue
Browse files Browse the repository at this point in the history
  • Loading branch information
dstamoulis committed Apr 8, 2019
1 parent 01246c5 commit 21c1f3e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions nas-search/plot-progress/parse_search_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def encode_single_path_nas_arch(inds, hard=False):
for layer_cnt in range(20):

inds_row = inds[layer_cnt]
print(inds_row)
if inds_row == [0.0, 0.0, 0.0]:
idx = 4 # skip
elif inds_row == [0.0, 0.0, 1.0]:
Expand Down
10 changes: 5 additions & 5 deletions train-final/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def encode(self, blocks_args):


import parse_netarch
def parse_netarch(parse_lambda_dir, depth_multiplier=None):
def parse_netarch_model(parse_lambda_dir, depth_multiplier=None):
"""Creates the RNAS found model. No need to hard-code
model, it parses the output of previous search
Expand All @@ -128,7 +128,7 @@ def parse_netarch(parse_lambda_dir, depth_multiplier=None):
indicator_values = parse_netarch.parse_indicators_single_path_nas(parse_lambda_dir, tf_size_guidance)
network = parse_netarch.encode_single_path_nas_arch(indicator_values)
parse_netarch.print_net(network)
blocks_args = parse_netarch.mnasnet_encoder(network)
blocks_args = parse_netarch.convnet_encoder(network)
parse_netarch.print_encoded_net(blocks_args)

decoder = MnasNetDecoder()
Expand All @@ -146,7 +146,7 @@ def parse_netarch(parse_lambda_dir, depth_multiplier=None):


def build_model(images, model_name, training, override_params=None,
parse_output_dir=None):
parse_search_dir=None):
"""A helper functiion to creates a ConvNet model and returns predicted logits.
Args:
Expand All @@ -165,8 +165,8 @@ def build_model(images, model_name, training, override_params=None,
"""
assert isinstance(images, tf.Tensor)
if model_name == 'single-path':
assert parse_output_dir is not None
blocks_args, global_params = parse_netarch(parse_output_dir)
assert parse_search_dir is not None
blocks_args, global_params = parse_netarch_model(parse_search_dir)
else:
raise NotImplementedError('model name is not pre-defined: %s' % model_name)

Expand Down

0 comments on commit 21c1f3e

Please sign in to comment.