Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: dimapihtar <[email protected]>
  • Loading branch information
dimapihtar committed Aug 27, 2024
1 parent 333bc51 commit a922472
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,8 @@ def generate_base_config(
:return: base config object for the given model.
:rtype: dict
"""
base_cfg = utils.generic_base_config(model_name=model_name, model_version=model_version, model_size_in_b=model_size_in_b, cfg=cfg)

base_cfg = utils.generic_base_config(
model_name=model_name, model_version=model_version, model_size_in_b=model_size_in_b, cfg=cfg
)

return base_cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def search_training_config(
# Generate candidate configs.
configs = generate_grid_search_configs(base_cfg, train_cfg, model_size_in_b, model_name)
# Launch candidate configs.
#job_ids = launch_grid_search_configs(base_dir, results_cfgs, model_name, cfg)
# job_ids = launch_grid_search_configs(base_dir, results_cfgs, model_name, cfg)
# Measure and compare throughputs for each config.
#launch_throughput_measure(job_ids, model_name, model_size_in_b, num_nodes, hydra_args, cfg)
# launch_throughput_measure(job_ids, model_name, model_size_in_b, num_nodes, hydra_args, cfg)

return configs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

from nemo.collections.llm.tools.auto_configurator import base_configs

MODULES = {
"llama": "Llama"
}
MODULES = {"llama": "Llama"}


def _calculate_model_size(
vocab_size: int = None,
Expand Down Expand Up @@ -308,7 +307,9 @@ def calculate_model_size_params(
raise Exception("Number of layers not found, config is not possible.")


def generic_base_config(model_name: str = "llama", model_version: int = 2, model_size_in_b: int = 7, cfg: dict = {}) -> dict:
def generic_base_config(
model_name: str = "llama", model_version: int = 2, model_size_in_b: int = 7, cfg: dict = {}
) -> dict:
"""
Generates a base config dictionary from a base config yaml file.
:param omegaconf.dictconfig.DictConfig cfg: hydra-like config object for the HP tool.
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/llm/tools/auto_configurator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ def main():
args = get_args()
configs = search_config(cfg=vars(args))


if __name__ == "__main__":
main()

0 comments on commit a922472

Please sign in to comment.