From 6f1505b71b985cd9d0a8eb419307ee4d01531cb9 Mon Sep 17 00:00:00 2001 From: J-Dymond Date: Thu, 28 Nov 2024 13:50:47 +0000 Subject: [PATCH] changed argument parser to jsonargparse.CLI in the inference scripts --- scripts/pipeline_inference.py | 41 +++++++---------- scripts/single_component_inference.py | 65 +++++++++++---------------- 2 files changed, 42 insertions(+), 64 deletions(-) diff --git a/scripts/pipeline_inference.py b/scripts/pipeline_inference.py index c2f46f5..7242704 100644 --- a/scripts/pipeline_inference.py +++ b/scripts/pipeline_inference.py @@ -1,7 +1,8 @@ -import argparse import json import os +from jsonargparse import CLI + from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation from arc_spice.eval.inference_utils import ResultsGetter, run_inference from arc_spice.utils import open_yaml_path @@ -12,10 +13,17 @@ OUTPUT_DIR = "outputs" -def main(args): +def main(pipeline_config_pth: str, data_config_pth: str): + """ + Run inference on a given pipeline with provided data config + + Args: + pipeline_config_pth: path to pipeline config yaml file + data_config_pth: path to data config yaml file + """ # initialise pipeline - data_config = open_yaml_path(args.data_config) - pipeline_config = open_yaml_path(args.pipeline_config) + data_config = open_yaml_path(data_config_pth) + pipeline_config = open_yaml_path(pipeline_config_pth) data_sets, meta_data = load_multieurlex_for_translation(**data_config) test_loader = data_sets["test"] rtc_variational_pipeline = RTCVariationalPipeline( @@ -29,8 +37,8 @@ def main(args): results_getter=results_getter, ) - data_name = args.data_config.split("/")[-1].split(".")[0] - pipeline_name = args.pipeline_config.split("/")[-1].split(".")[0] + data_name = data_config_pth.split("/")[-1].split(".")[0] + pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0] save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}" os.makedirs(save_loc, exist_ok=True) @@ -39,23 +47,4 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "From an experiment path generates evaluation plots for every experiment." - ) - ) - parser.add_argument( - "pipeline_config", - type=str, - default=None, - help="Path to pipeline config.", - ) - parser.add_argument( - "data_config", - type=str, - default=None, - help="Path to data config.", - ) - args = parser.parse_args() - - main(args) + CLI(main) diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py index 1b0ef0b..a2c4bfc 100644 --- a/scripts/single_component_inference.py +++ b/scripts/single_component_inference.py @@ -11,10 +11,11 @@ - output/check_callibration/pipeline_name/run_[X]/[OUTPUT FILES HERE] """ -import argparse import json import os +from jsonargparse import CLI + from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation from arc_spice.eval.inference_utils import ResultsGetter, run_inference from arc_spice.utils import open_yaml_path @@ -27,24 +28,39 @@ OUTPUT_DIR = "outputs" -def main(args): +def main(pipeline_config_pth: str, data_config_pth: str, model_key: str): + """ + Run inference on a given pipeline component with provided data config and model key. + + Args: + pipeline_config_pth: path to pipeline config yaml file + data_config_pth: path to data config yaml file + model_key: name of model on which to run inference + """ # initialise pipeline - data_config = open_yaml_path(args.data_config) - pipeline_config = open_yaml_path(args.pipeline_config) + data_config = open_yaml_path(data_config_pth) + pipeline_config = open_yaml_path(pipeline_config_pth) data_sets, meta_data = load_multieurlex_for_translation(**data_config) test_loader = data_sets["test"] - if args.model_key == "ocr": + if model_key == "ocr": rtc_single_component_pipeline = RecognitionVariationalPipeline( model_pars=pipeline_config, data_pars=meta_data ) - if args.model_key == "translator": + elif model_key == "translator": rtc_single_component_pipeline = TranslationVariationalPipeline( model_pars=pipeline_config, data_pars=meta_data ) - if args.model_key == "classifier": + elif model_key == "classifier": rtc_single_component_pipeline = ClassificationVariationalPipeline( model_pars=pipeline_config, data_pars=meta_data ) + else: + error_msg = ( + "model_key should be: 'ocr', 'translator', or 'classifier'." + f" Given: {model_key}" + ) + raise ValueError(error_msg) + results_getter = ResultsGetter(meta_data["n_classes"]) test_results = run_inference( @@ -53,44 +69,17 @@ def main(args): results_getter=results_getter, ) - data_name = args.data_config.split("/")[-1].split(".")[0] - pipeline_name = args.pipeline_config.split("/")[-1].split(".")[0] + data_name = data_config_pth.split("/")[-1].split(".")[0] + pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0] save_loc = ( f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/" f"single_component" ) os.makedirs(save_loc, exist_ok=True) - with open(f"{save_loc}/{args.model_key}.json", "w") as save_file: + with open(f"{save_loc}/{model_key}.json", "w") as save_file: json.dump(test_results, save_file) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "From an experiment path generates evaluation plots for every experiment." - ) - ) - parser.add_argument( - "pipeline_config", - type=str, - default=None, - help="Path to pipeline config.", - ) - parser.add_argument( - "data_config", - type=str, - default=None, - help="Path to data config.", - ) - - parser.add_argument( - "model_key", - type=str, - default=None, - help="Model on which to run inference.", - ) - - args = parser.parse_args() - - main(args) + CLI(main)