Skip to content

Commit

Permalink
changed argument parser to jsonargparse.CLI in the inference scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 28, 2024
1 parent e6eb4c8 commit 6f1505b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 64 deletions.
41 changes: 15 additions & 26 deletions scripts/pipeline_inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import argparse
import json
import os

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
Expand All @@ -12,10 +13,17 @@
OUTPUT_DIR = "outputs"


def main(args):
def main(pipeline_config_pth: str, data_config_pth: str):
"""
    Run inference on a given pipeline with the provided data config.
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
"""
# initialise pipeline
data_config = open_yaml_path(args.data_config)
pipeline_config = open_yaml_path(args.pipeline_config)
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
test_loader = data_sets["test"]
rtc_variational_pipeline = RTCVariationalPipeline(
Expand All @@ -29,8 +37,8 @@ def main(args):
results_getter=results_getter,
)

data_name = args.data_config.split("/")[-1].split(".")[0]
pipeline_name = args.pipeline_config.split("/")[-1].split(".")[0]
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
os.makedirs(save_loc, exist_ok=True)

Expand All @@ -39,23 +47,4 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"From an experiment path generates evaluation plots for every experiment."
)
)
parser.add_argument(
"pipeline_config",
type=str,
default=None,
help="Path to pipeline config.",
)
parser.add_argument(
"data_config",
type=str,
default=None,
help="Path to data config.",
)
args = parser.parse_args()

main(args)
CLI(main)
65 changes: 27 additions & 38 deletions scripts/single_component_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
- output/check_callibration/pipeline_name/run_[X]/[OUTPUT FILES HERE]
"""

import argparse
import json
import os

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
Expand All @@ -27,24 +28,39 @@
OUTPUT_DIR = "outputs"


def main(args):
def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
"""
    Run inference on a given pipeline component with the provided data config and model key.
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
model_key: name of model on which to run inference
"""
# initialise pipeline
data_config = open_yaml_path(args.data_config)
pipeline_config = open_yaml_path(args.pipeline_config)
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
test_loader = data_sets["test"]
if args.model_key == "ocr":
if model_key == "ocr":
rtc_single_component_pipeline = RecognitionVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
if args.model_key == "translator":
elif model_key == "translator":
rtc_single_component_pipeline = TranslationVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
if args.model_key == "classifier":
elif model_key == "classifier":
rtc_single_component_pipeline = ClassificationVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
)
else:
error_msg = (
"model_key should be: 'ocr', 'translator', or 'classifier'."
f" Given: {model_key}"
)
raise ValueError(error_msg)

results_getter = ResultsGetter(meta_data["n_classes"])

test_results = run_inference(
Expand All @@ -53,44 +69,17 @@ def main(args):
results_getter=results_getter,
)

data_name = args.data_config.split("/")[-1].split(".")[0]
pipeline_name = args.pipeline_config.split("/")[-1].split(".")[0]
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"single_component"
)
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/{args.model_key}.json", "w") as save_file:
with open(f"{save_loc}/{model_key}.json", "w") as save_file:
json.dump(test_results, save_file)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"From an experiment path generates evaluation plots for every experiment."
)
)
parser.add_argument(
"pipeline_config",
type=str,
default=None,
help="Path to pipeline config.",
)
parser.add_argument(
"data_config",
type=str,
default=None,
help="Path to data config.",
)

parser.add_argument(
"model_key",
type=str,
default=None,
help="Model on which to run inference.",
)

args = parser.parse_args()

main(args)
CLI(main)

0 comments on commit 6f1505b

Please sign in to comment.