From 340215df173c392590f90f3a7d3384fe96821932 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Fri, 29 Nov 2024 17:03:35 +0000
Subject: [PATCH 01/14] added jobscript generation script for running
 experiments

---
 config/RTC_configs/roberta-mt5-zero-shot.yaml |  2 +-
 .../baskerville_pipeline_inference_test.yaml  | 13 +++++
 scripts/gen_jobscripts.py                     | 51 +++++++++++++++++++
 scripts/pipeline_inference.py                 | 18 +++++--
 scripts/single_component_inference.py         | 16 ++++--
 slurm_scripts/single_component_inference.sh   | 27 ++++++++++
 src/arc_spice/config/jobscript_template.sh    | 28 ++++++++++
 7 files changed, 146 insertions(+), 9 deletions(-)
 create mode 100644 config/experiment/baskerville_pipeline_inference_test.yaml
 create mode 100644 scripts/gen_jobscripts.py
 create mode 100644 slurm_scripts/single_component_inference.sh
 create mode 100644 src/arc_spice/config/jobscript_template.sh

diff --git a/config/RTC_configs/roberta-mt5-zero-shot.yaml b/config/RTC_configs/roberta-mt5-zero-shot.yaml
index 506a6e8..85a2d79 100644
--- a/config/RTC_configs/roberta-mt5-zero-shot.yaml
+++ b/config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,4 +1,4 @@
-OCR:
+ocr:
   specific_task: "image-to-text"
   model: "microsoft/trocr-base-handwritten"

diff --git a/config/experiment/baskerville_pipeline_inference_test.yaml b/config/experiment/baskerville_pipeline_inference_test.yaml
new file mode 100644
index 0000000..fd7232e
--- /dev/null
+++ b/config/experiment/baskerville_pipeline_inference_test.yaml
@@ -0,0 +1,13 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+
+bask:
+  jobname: "test"
+  walltime: '0-12:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
diff --git a/scripts/gen_jobscripts.py b/scripts/gen_jobscripts.py
new file mode 100644
index 0000000..e83e53d
--- /dev/null
+++ b/scripts/gen_jobscripts.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+
+from jinja2 import Environment, FileSystemLoader
+from jsonargparse import CLI
+
+from arc_spice.utils import open_yaml_path
+
+PROJECT_DIR = Path(__file__, "..", "..").resolve()
+
+
+def main(experiment_config_path: str):
+    """
+    Generate SLURM jobscripts for each model in an experiment's pipeline config.
+
+    Args:
+        experiment_config_path: path to the experiment config yaml file
+    """
+    experiment_config = open_yaml_path(experiment_config_path)
+    pipeline_conf_dir = (
+        f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
+    )
+    data_conf_dir = (
+        f"{PROJECT_DIR}/config/data_configs/{experiment_config['data_config']}.yaml"
+    )
+    pipeline_config = open_yaml_path(pipeline_conf_dir)
+    # Get jinja template
+    print(PROJECT_DIR / "src" / "arc_spice" / "configs")
+    environment = Environment(
+        loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
+    )
+    template = environment.get_template("jobscript_template.sh")
+    for model in pipeline_config:
+        script_dict: dict = experiment_config["bask"]
+        script_dict.update(
+            {
+                "script_name": (
+                    "single_component_inference.py "
+                    f"{pipeline_conf_dir} {data_conf_dir} {model}"
+                ),
+                "array_number": 0,
+                "job_name": "test",
+            }
+        )
+        train_script = template.render(script_dict)
+
+        with open(f"temp/{model}_test.sh", "w") as f:
+            f.write(train_script)
+
+
+if __name__ == "__main__":
+    CLI(main)
diff --git a/scripts/pipeline_inference.py b/scripts/pipeline_inference.py
index 7242704..0700c6e 100644
--- a/scripts/pipeline_inference.py
+++ b/scripts/pipeline_inference.py
@@ -3,9 +3,9 @@
 
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
     RTCVariationalPipeline,
 )
@@ -13,18 +13,24 @@
 OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str):
+def main(
+    pipeline_config_pth: str, data_config_pth: str, seed: int, experiment_name: str
+):
     """
     Run inference on a given pipeline with provided data config
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
     """
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     rtc_variational_pipeline = RTCVariationalPipeline(
         model_pars=pipeline_config, data_pars=meta_data
@@ -39,7 +45,9 @@ def main(pipeline_config_pth: str, data_config_pth: str):
 
     data_name = data_config_pth.split("/")[-1].split(".")[0]
     pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}"
+    )
     os.makedirs(save_loc, exist_ok=True)
 
     with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index a2c4bfc..2904fb0 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -18,7 +18,7 @@
 
 from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
-from arc_spice.utils import open_yaml_path
+from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     ClassificationVariationalPipeline,
     RecognitionVariationalPipeline,
@@ -28,15 +28,25 @@
 OUTPUT_DIR = "outputs"
 
 
-def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
+def main(
+    pipeline_config_pth: str,
+    data_config_pth: str,
+    seed: int,
+    experiment_name: str,
+    model_key: str,
+):
     """
     Run inference on a given pipeline component with provided data config and
     model key.
 
     Args:
         pipeline_config_pth: path to pipeline config yaml file
         data_config_pth: path to data config yaml file
+        seed: seed for the inference pass
+        experiment_name: name of experiment for saving purposes
         model_key: name of model on which to run inference
     """
+    # seed experiment
+    seed_everything(seed=seed)
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
@@ -72,7 +82,7 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
     data_name = data_config_pth.split("/")[-1].split(".")[0]
     pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
     save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}"
         f"single_component"
     )
     os.makedirs(save_loc, exist_ok=True)
diff --git a/slurm_scripts/single_component_inference.sh b/slurm_scripts/single_component_inference.sh
new file mode 100644
index 0000000..a76f23d
--- /dev/null
+++ b/slurm_scripts/single_component_inference.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name SPICE_variational_RTC
+#SBATCH --time 0-12:0:0
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/%j.out
+#SBATCH --cpus-per-gpu 18
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python scripts/single_component_inference.py
diff --git a/src/arc_spice/config/jobscript_template.sh b/src/arc_spice/config/jobscript_template.sh
new file mode 100644
index 0000000..9a85517
--- /dev/null
+++ b/src/arc_spice/config/jobscript_template.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name {{ job_name }}
+#SBATCH --time {{ walltime }}
+#SBATCH --nodes {{ node_number }}
+#SBATCH --gpus {{ gpu_number }}
+#SBATCH --output ./slurm_logs/{{ job_name }}-%j.out
+#SBATCH --array=0-{{ array_number }}
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="{{ hf_cache_dir }}"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python {{ script_name }}

From 52cd7b608e5d43ee74ed0da620a4f4283b675ab0 Mon Sep 17 00:00:00 2001
From: tbc
Date: Fri, 29 Nov 2024 20:14:52 +0000
Subject: [PATCH 02/14] fixes for inference pipeline

---
 scripts/gen_jobscripts.py                  | 11 ++++---
 scripts/single_component_inference.py      |  4 +--
 slurm_scripts/classifier_test.sh           | 29 +++++++++++++++++++
 slurm_scripts/classifier_test_short.sh     | 29 +++++++++++++++++++
 slurm_scripts/ocr_test.sh                  | 29 +++++++++++++++++++
 slurm_scripts/translator_test.sh           | 29 +++++++++++++++++++
 src/arc_spice/config/jobscript_template.sh |  3 +-
 src/arc_spice/data/multieurlex_utils.py    |  2 +-
 .../RTC_single_component_pipeline.py       |  2 +-
 9 files changed, 129 insertions(+), 9 deletions(-)
 create mode 100644 slurm_scripts/classifier_test.sh
 create mode 100644 slurm_scripts/classifier_test_short.sh
 create mode 100644 slurm_scripts/ocr_test.sh
 create mode 100644 slurm_scripts/translator_test.sh

diff --git a/scripts/gen_jobscripts.py b/scripts/gen_jobscripts.py
index e83e53d..a81ab9f 100644
--- a/scripts/gen_jobscripts.py
+++ b/scripts/gen_jobscripts.py
@@ -15,6 +15,7 @@ def main(experiment_config_path: str):
     Args:
         experiment_config_path: path to the experiment config yaml file
     """
+    experiment_name = experiment_config_path.split('/')[-1].split('.')[0]
     experiment_config = open_yaml_path(experiment_config_path)
     pipeline_conf_dir = (
         f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
@@ -31,19 +32,21 @@ def main(experiment_config_path: str):
     template = environment.get_template("jobscript_template.sh")
     for model in pipeline_config:
         script_dict: dict = experiment_config["bask"]
+        seed = experiment_config["seed"][0]
         script_dict.update(
             {
                 "script_name": (
-                    "single_component_inference.py "
-                    f"{pipeline_conf_dir} {data_conf_dir} {model}"
+                    "scripts/single_component_inference.py "
+                    f"{pipeline_conf_dir} {data_conf_dir} {seed} {experiment_name} {model}"
                 ),
                 "array_number": 0,
-                "job_name": "test",
+                "job_name": f"{experiment_name}_{model}",
+                "seed": seed
             }
         )
         train_script = template.render(script_dict)
 
-        with open(f"temp/{model}_test.sh", "w") as f:
+        with open(f"slurm_scripts/{model}_test.sh", "w") as f:
             f.write(train_script)
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index 2904fb0..93be612 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -16,7 +16,7 @@
 
 from jsonargparse import CLI
 
-from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
+from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
 from arc_spice.eval.inference_utils import ResultsGetter, run_inference
 from arc_spice.utils import open_yaml_path, seed_everything
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
@@ -50,7 +50,7 @@ def main(
     # initialise pipeline
     data_config = open_yaml_path(data_config_pth)
     pipeline_config = open_yaml_path(pipeline_config_pth)
-    data_sets, meta_data = load_multieurlex_for_translation(**data_config)
+    data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
     test_loader = data_sets["test"]
     if model_key == "ocr":
         rtc_single_component_pipeline = RecognitionVariationalPipeline(
diff --git a/slurm_scripts/classifier_test.sh b/slurm_scripts/classifier_test.sh
new file mode 100644
index 0000000..e8351b1
--- /dev/null
+++ b/slurm_scripts/classifier_test.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name baskerville_pipeline_inference_test_classifier
+#SBATCH --time 0-12:0:0
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_classifier-%j.out
+#SBATCH --array=0-0
+#SBATCH --cpus-per-gpu 18
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test classifier
\ No newline at end of file
diff --git a/slurm_scripts/classifier_test_short.sh b/slurm_scripts/classifier_test_short.sh
new file mode 100644
index 0000000..a59b8a7
--- /dev/null
+++ b/slurm_scripts/classifier_test_short.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name baskerville_pipeline_inference_test_classifier
+#SBATCH --time 0-0:12:0
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_classifier-%j.out
+#SBATCH --array=0-0
+#SBATCH --cpus-per-gpu 18
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test classifier
\ No newline at end of file
diff --git a/slurm_scripts/ocr_test.sh b/slurm_scripts/ocr_test.sh
new file mode 100644
index 0000000..0146785
--- /dev/null
+++ b/slurm_scripts/ocr_test.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name baskerville_pipeline_inference_test_ocr
+#SBATCH --time 0-12:0:0
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_ocr-%j.out
+#SBATCH --array=0-0
+#SBATCH --cpus-per-gpu 18
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test ocr
\ No newline at end of file
diff --git a/slurm_scripts/translator_test.sh b/slurm_scripts/translator_test.sh
new file mode 100644
index 0000000..bd518c2
--- /dev/null
+++ b/slurm_scripts/translator_test.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+#SBATCH --account vjgo8416-spice
+#SBATCH --qos turing
+#SBATCH --job-name baskerville_pipeline_inference_test_translator
+#SBATCH --time 0-36:0:0
+#SBATCH --nodes 1
+#SBATCH --gpus 1
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_translator-%j.out
+#SBATCH --array=0-0
+#SBATCH --cpus-per-gpu 18
+
+
+# Load required modules here
+module purge
+module load baskerville
+module load bask-apps/live/live
+module load Python/3.10.8-GCCcore-12.2.0
+
+
+# change working directory
+cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
+
+source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
+
+# change huggingface cache to be in project dir rather than user home
+export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
+
+# TODO: script uses relative path to project home so must be run from home, fix
+python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test translator
\ No newline at end of file
diff --git a/src/arc_spice/config/jobscript_template.sh b/src/arc_spice/config/jobscript_template.sh
index 9a85517..6583d2b 100644
--- a/src/arc_spice/config/jobscript_template.sh
+++ b/src/arc_spice/config/jobscript_template.sh
@@ -5,8 +5,9 @@
 #SBATCH --time {{ walltime }}
 #SBATCH --nodes {{ node_number }}
 #SBATCH --gpus {{ gpu_number }}
-#SBATCH --output ./slurm_logs/{{ job_name }}-%j.out
+#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
 #SBATCH --array=0-{{ array_number }}
+#SBATCH --cpus-per-gpu 18
 
 
 # Load required modules here
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index d46da5e..bbc82d3 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -6,7 +6,7 @@
 from datasets.formatting.formatting import LazyRow
 from PIL import Image
 from torch.nn.functional import one_hot
-from trdg.generators import GeneratorFromStrings
+# from trdg.generators import GeneratorFromStrings
 
 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index 9a13646..b1c348b 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -120,7 +120,7 @@ def __init__(
         self,
         model_pars: dict[str, dict[str, str]],
         n_variational_runs=5,
-        translation_batch_size=8,
+        translation_batch_size=4,
         **kwargs,
     ):
         self.set_device()

From 279c2ec1f45fa447df853012bb496ff33317faf2 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Fri, 29 Nov 2024 20:45:50 +0000
Subject: [PATCH 03/14] fixed error functions and added appropriate tests

---
 scripts/gen_jobscripts.py                  |  7 +++--
 src/arc_spice/data/multieurlex_utils.py    |  2 +-
 src/arc_spice/eval/classification_error.py |  7 +++++
 src/arc_spice/eval/inference_utils.py      |  5 +--
 .../RTC_single_component_pipeline.py       |  4 +--
 .../RTC_variational_pipeline.py            |  4 +--
 tests/test_inference.py                    | 31 +++++++++++++++++++
 7 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/scripts/gen_jobscripts.py b/scripts/gen_jobscripts.py
index a81ab9f..f96a764 100644
--- a/scripts/gen_jobscripts.py
+++ b/scripts/gen_jobscripts.py
@@ -15,7 +15,7 @@ def main(experiment_config_path: str):
     Args:
         experiment_config_path: path to the experiment config yaml file
     """
-    experiment_name = experiment_config_path.split('/')[-1].split('.')[0]
+    experiment_name = experiment_config_path.split("/")[-1].split(".")[0]
     experiment_config = open_yaml_path(experiment_config_path)
     pipeline_conf_dir = (
         f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
@@ -37,11 +37,12 @@ def main(experiment_config_path: str):
             {
                 "script_name": (
                     "scripts/single_component_inference.py "
-                    f"{pipeline_conf_dir} {data_conf_dir} {seed} {experiment_name} {model}"
+                    f"{pipeline_conf_dir} {data_conf_dir} {seed}"
+                    f" {experiment_name} {model}"
                 ),
                 "array_number": 0,
                 "job_name": f"{experiment_name}_{model}",
-                "seed": seed
+                "seed": seed,
             }
         )
         train_script = template.render(script_dict)
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index bbc82d3..d46da5e 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -6,7 +6,7 @@
 from datasets.formatting.formatting import LazyRow
 from PIL import Image
 from torch.nn.functional import one_hot
-# from trdg.generators import GeneratorFromStrings
+from trdg.generators import GeneratorFromStrings
 
 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {
diff --git a/src/arc_spice/eval/classification_error.py b/src/arc_spice/eval/classification_error.py
index 2e8f36c..71292d3 100644
--- a/src/arc_spice/eval/classification_error.py
+++ b/src/arc_spice/eval/classification_error.py
@@ -1,4 +1,7 @@
+import math
+
 import torch
+from sklearn.metrics import zero_one_loss
 
 
 def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
@@ -8,6 +11,10 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
     return 1 - torch.mean(distance)
 
 
+def zero_one_loss_ceil(y_target, y_pred):
+    return math.ceil(zero_one_loss(y_target, y_pred, normalize=True))
+
+
 def MC_dropout_scores(
     variational_probs: list[float], epsilon: float = 1e-14
 ) -> dict[str, torch.Tensor]:
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 0794fa2..7cd64ed 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -3,11 +3,12 @@
 from typing import Any
 
 import torch
-from sklearn.metrics import hamming_loss, zero_one_loss
+from sklearn.metrics import hamming_loss
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from arc_spice.data.multieurlex_utils import MultiHot
+from arc_spice.eval.classification_error import zero_one_loss_ceil
 from arc_spice.eval.translation_error import get_comet_model
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     RTCSingleComponentPipeline,
@@ -112,7 +113,7 @@ def classification_results(
     preds = torch.round(mean_scores).tolist()
     labels = self.multihot(test_row["labels"])
     hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
-    zero_one_acc = zero_one_loss(y_pred=preds, y_true=labels)
+    zero_one_acc = zero_one_loss_ceil(y_pred=preds, y_target=labels)
 
     return ClassificationResults(
         mean_scores=mean_scores.detach().tolist(),
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index b1c348b..0be0025 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -98,8 +98,8 @@ def __init__(
     ):
         self.set_device()
         self.ocr = pipeline(
-            task=model_pars["OCR"]["specific_task"],
-            model=model_pars["OCR"]["model"],
+            task=model_pars["ocr"]["specific_task"],
+            model=model_pars["ocr"]["model"],
             device=self.device,
             **kwargs,
         )
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index eba0c16..0196e6e 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -41,8 +41,8 @@ def __init__(
         super().__init__(n_variational_runs, translation_batch_size)
         # defining the pipeline objects
         self.ocr = pipeline(
-            task=model_pars["OCR"]["specific_task"],
-            model=model_pars["OCR"]["model"],
+            task=model_pars["ocr"]["specific_task"],
+            model=model_pars["ocr"]["model"],
             device=self.device,
         )
         self.translator = pipeline(
diff --git a/tests/test_inference.py b/tests/test_inference.py
index a7e6b08..19fa240 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -2,7 +2,9 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+from sklearn.metrics import hamming_loss
 
+from arc_spice.eval.classification_error import zero_one_loss_ceil
 from arc_spice.utils import open_yaml_path
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     ClassificationVariationalPipeline,
@@ -40,6 +42,35 @@ def dummy_metadata():
     }
 
 
+def test_errors():
+    dummy_target = [0, 1, 0, 1, 0]
+    dummy_middle_output = [1, 1, 0, 1, 0]
+
+    assert hamming_loss(dummy_target, dummy_middle_output) == pytest.approx(
+        0.2, abs=1e-5
+    )
+    assert zero_one_loss_ceil(dummy_target, dummy_middle_output) == pytest.approx(
+        1.0, abs=1e-5
+    )
+
+    dummy_correct_output = [0, 1, 0, 1, 0]
+
+    assert hamming_loss(dummy_target, dummy_correct_output) == pytest.approx(
+        0.0, abs=1e-5
+    )
+    assert zero_one_loss_ceil(dummy_target, dummy_correct_output) == pytest.approx(
+        0.0, abs=1e-5
+    )
+
+    dummy_incorrect_output = [1, 0, 1, 0, 1]
+    assert hamming_loss(dummy_target, dummy_incorrect_output) == pytest.approx(
+        1.0, abs=1e-5
+    )
+    assert zero_one_loss_ceil(dummy_target, dummy_incorrect_output) == pytest.approx(
+        1.0, abs=1e-5
+    )
+
+
 def test_pipeline_inputs(dummy_data, dummy_metadata):
     pipeline_config = open_yaml_path(PIPELINE_PATH)

From 11743e971c7a421abf9d9e8e2cb9de3bffb16fa8 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Mon, 2 Dec 2024 12:52:15 +0000
Subject: [PATCH 04/14] changes to the forward methods so that logits are not
 stored outside of the custom pipeline method

---
 scripts/single_component_inference.py         | 2 +-
 .../RTC_variational_pipeline.py               | 6 +++++-
 src/arc_spice/variational_pipelines/utils.py  | 9 +++------
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index 93be612..8446b0d 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -82,7 +82,7 @@ def main(
     data_name = data_config_pth.split("/")[-1].split(".")[0]
     pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
     save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}"
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}/"
         f"single_component"
     )
     os.makedirs(save_loc, exist_ok=True)
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index 0196e6e..106b581 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -2,6 +2,7 @@
 from typing import Any
 
 import torch
+from torch.nn.functional import softmax
 from transformers import TranslationPipeline, pipeline
 
 from arc_spice.variational_pipelines.utils import (
@@ -180,5 +181,8 @@ def _forward(self, model_inputs, **generate_kwargs):
         # logits are a tuple of length output_ids[-1]-1
         # each element is a tensor of shape (batch_size, vocab_size)
         logits = torch.stack(out["logits"], dim=1)
+        # get softmax of the logits to get token probabilities
+        softmax_logits = softmax(logits, dim=-1)
+        max_token_scores = torch.max(softmax_logits, dim=-1).values
 
-        return {"output_ids": output_ids, "logits": logits}
+        return {"output_ids": output_ids, "scores": max_token_scores}
diff --git a/src/arc_spice/variational_pipelines/utils.py b/src/arc_spice/variational_pipelines/utils.py
index c9e2692..1de7427 100644
--- a/src/arc_spice/variational_pipelines/utils.py
+++ b/src/arc_spice/variational_pipelines/utils.py
@@ -259,14 +259,11 @@ def translate(self, text: str) -> dict[str, torch.Tensor | str]:
         ]
         # join these to create the full translation
         full_translation = ("").join(sentence_translations)
-        # get softmax of the logits to get token probabilities
-        softmax_logits = softmax(translator_outputs[0]["raw_outputs"]["logits"], dim=-1)
-        max_token_scores = torch.max(softmax_logits, dim=-1).values.squeeze(dim=0)
         # record the output and token probabilities
        confidence_metrics = [
             {
                 "outputs": translator_output["translation_text"],
-                "probs": max_token_scores,
+                "probs": translator_output["raw_outputs"]["scores"],
             }
             for translator_output in translator_outputs
         ]
@@ -373,7 +370,8 @@ def sentence_density(
 
         # TODO vectorize
         # calculate conditional probabilities take power first to avoid NaN
-        for var_index, var_score in enumerate(var_scores):
+        for var_index, var_score_out in enumerate(var_scores):
+            var_score = var_score_out.squeeze()
             cond_probs[var_index] = torch.prod(
                 torch.pow(var_score, 1 / len(var_score)), dim=-1
             )
@@ -381,7 +379,6 @@ def sentence_density(
         semantic_density = (1 / torch.sum(cond_probs)) * torch.sum(
             torch.mul(cond_probs, kernel_funcs)
         )
-
         return semantic_density.item(), sequence_length
 
     def translation_semantic_density(

From 439ae28f1fa15c995097ab39a69de083319801df Mon Sep 17 00:00:00 2001
From: tbc
Date: Mon, 2 Dec 2024 14:39:18 +0000
Subject: [PATCH 05/14] adding minor inference argument changes

---
 slurm_scripts/translator_test.sh        | 6 +++---
 src/arc_spice/data/multieurlex_utils.py | 2 +-
 .../RTC_variational_pipeline.py         | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/slurm_scripts/translator_test.sh b/slurm_scripts/translator_test.sh
index bd518c2..73dc7fe 100644
--- a/slurm_scripts/translator_test.sh
+++ b/slurm_scripts/translator_test.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 #SBATCH --account vjgo8416-spice
 #SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_translator
-#SBATCH --time 0-36:0:0
+#SBATCH --job-name baskerville_pipeline_inference_test_translator_larger_bs
+#SBATCH --time 0-12:0:0
 #SBATCH --nodes 1
 #SBATCH --gpus 1
 #SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_translator-%j.out
@@ -26,4 +26,4 @@ source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
 export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
 
 # TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test translator
\ No newline at end of file
+python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test_large_bs translator
\ No newline at end of file
diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index d46da5e..bbc82d3 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -6,7 +6,7 @@
 from datasets.formatting.formatting import LazyRow
 from PIL import Image
 from torch.nn.functional import one_hot
-from trdg.generators import GeneratorFromStrings
+# from trdg.generators import GeneratorFromStrings
 
 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index 106b581..15bf990 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -37,7 +37,7 @@ def __init__(
         model_pars: dict[str, dict[str, str]],
         data_pars: dict[str, Any],
         n_variational_runs=5,
-        translation_batch_size=8,
+        translation_batch_size=16,
     ) -> None:
         super().__init__(n_variational_runs, translation_batch_size)
         # defining the pipeline objects

From 4ebea1201c52b2288f1648a324c09ae132c3aa14 Mon Sep 17 00:00:00 2001
From: tbc
Date: Tue, 3 Dec 2024 11:18:03 +0000
Subject: [PATCH 06/14] added a drop length to the dataloader; this will
 prevent overly-long inputs from being used during inference

---
 src/arc_spice/data/multieurlex_utils.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/arc_spice/data/multieurlex_utils.py b/src/arc_spice/data/multieurlex_utils.py
index bbc82d3..871217d 100644
--- a/src/arc_spice/data/multieurlex_utils.py
+++ b/src/arc_spice/data/multieurlex_utils.py
@@ -6,7 +6,7 @@
 from datasets.formatting.formatting import LazyRow
 from PIL import Image
 from torch.nn.functional import one_hot
-# from trdg.generators import GeneratorFromStrings
+from trdg.generators import GeneratorFromStrings
 
 # For identifying where the adopted decisions begin
 ARTICLE_1_MARKERS = {
@@ -133,6 +133,7 @@ def load_multieurlex(
     level: int,
     languages: list[str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     split: str | None = None,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     """
@@ -188,6 +189,11 @@ def load_multieurlex(
             lambda x: all(x is not None for x in x["text"].values())
         )
 
+    if drop_length:
+        dataset_dict = dataset_dict.filter(
+            lambda x: len(x["text"][languages[0]]) <= drop_length
+        )
+
     # return datasets and metadata
     return dataset_dict, metadata
 
@@ -197,11 +203,16 @@ def load_multieurlex_for_pipeline(
     level: int,
     lang_pair: dict[str, str],
     drop_empty: bool = True,
+    drop_length: int | None = None,
     load_ocr_data: bool = False,
 ) -> tuple[datasets.DatasetDict, dict[str, Any]]:
     langs = [lang_pair["source"], lang_pair["target"]]
     dataset_dict, meta_data = load_multieurlex(
-        data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
+        data_dir=data_dir,
+        level=level,
+        languages=langs,
+        drop_empty=drop_empty,
+        drop_length=drop_length,
     )
     # instantiate the preprocessor
     preprocesser = TranslationPreProcesser(lang_pair)

From a9760cdbc5f58d12f75b68319748a25a2921098c Mon Sep 17 00:00:00 2001
From: tbc
Date: Tue, 3 Dec 2024 11:56:57 +0000
Subject: [PATCH 07/14] Added a drop_length parameter; if passed, it specifies
 a maximum character length of the first language passed to the dataset
 loader.

---
 config/data_configs/l1_fr_to_en.yaml                       | 2 ++
 config/experiment/baskerville_pipeline_inference_test.yaml | 2 +-
 slurm_scripts/translator_test.sh                           | 6 +++---
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/config/data_configs/l1_fr_to_en.yaml b/config/data_configs/l1_fr_to_en.yaml
index 1a3a373..58e12f1 100644
--- a/config/data_configs/l1_fr_to_en.yaml
+++ b/config/data_configs/l1_fr_to_en.yaml
@@ -5,3 +5,5 @@ level: 1
 lang_pair:
   source: "fr"
   target: "en"
+
+drop_length: 1000
diff --git a/config/experiment/baskerville_pipeline_inference_test.yaml b/config/experiment/baskerville_pipeline_inference_test.yaml
index fd7232e..281f1d7 100644
--- a/config/experiment/baskerville_pipeline_inference_test.yaml
+++ b/config/experiment/baskerville_pipeline_inference_test.yaml
@@ -6,7 +6,7 @@ seed:
   - 42
 
 bask:
-  jobname: "test"
+  jobname: "shortened_input_test"
   walltime: '0-12:0:0'
   gpu_number: 1
   node_number: 1
diff --git a/slurm_scripts/translator_test.sh b/slurm_scripts/translator_test.sh
index 73dc7fe..506603a 100644
--- a/slurm_scripts/translator_test.sh
+++ b/slurm_scripts/translator_test.sh
@@ -1,8 +1,8 @@
 #!/bin/bash
 #SBATCH --account vjgo8416-spice
 #SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_translator_larger_bs
-#SBATCH --time 0-12:0:0
+#SBATCH --job-name baskerville_pipeline_inference_test_translator
+#SBATCH --time 0-24:0:0
 #SBATCH --nodes 1
 #SBATCH --gpus 1
 #SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_translator-%j.out
@@ -26,4 +26,4 @@ source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
 export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
 
 # TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test_large_bs translator
\ No newline at end of file
+python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test translator
\ No newline at end of file

From 12cd2975788bf94c3e93c4ed5f30a2ba02479618 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Tue, 3 Dec 2024 14:56:04 +0000
Subject: [PATCH 08/14] dropped zero-one accuracy and added other metrics to
 the inference results output

---
 src/arc_spice/eval/inference_utils.py   | 17 +++++++++++------
 src/arc_spice/eval/translation_error.py |  5 +++++
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 7cd64ed..27513cf 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -8,8 +8,7 @@
 from tqdm import tqdm
 
 from arc_spice.data.multieurlex_utils import MultiHot
-from arc_spice.eval.classification_error import zero_one_loss_ceil
-from arc_spice.eval.translation_error import get_comet_model
+from arc_spice.eval.translation_error import conditional_probability, get_comet_model
 from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
     RTCSingleComponentPipeline,
 )
@@ -21,9 +20,9 @@
 RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
 ClassificationResults = namedtuple(
     "ClassificationResults",
     [
+        "clean_scores",
         "mean_scores",
         "hamming_accuracy",
-        "zero_one_accuracy",
         "mean_predicted_entropy",
     ],
 )
@@ -31,6 +30,7 @@
     "TranslationResults",
     [
         "full_output",
+        "clean_conditional_probability",
         "comet_score",
         "weighted_semantic_density",
     ],
 )
@@ -79,6 +79,10 @@ def translation_results(
         source_text = test_row["target_text"]
         target_text = test_row["target_text"]
         clean_translation = clean_output["translation"]["full_output"]
+        probs: list[torch.Tensor] = clean_output["translation"]["probs"]
+        clean_cond_prob = [
+            conditional_probability(prob.squeeze()).detach().tolist() for prob in probs
+        ]
 
         # define error model inputs
         comet_inp = [
@@ -97,6 +101,7 @@ def translation_results(
         return TranslationResults(
             comet_score=comet_output["scores"][0],
             full_output=clean_translation,
+            clean_conditional_probability=clean_cond_prob,
             weighted_semantic_density=var_output["translation"][
                 "weighted_semantic_density"
             ],
@@ -106,19 +111,19 @@ def classification_results(
         self,
         test_row: dict[str, Any],
         var_output: dict[str, dict],
-        **kwargs,
+        clean_output: dict[str, dict],
     ):
         # ### CLASSIFICATION ###
         mean_scores: torch.Tensor = var_output["classification"]["mean_scores"]
+        clean_scores: torch.Tensor = clean_output["classification"]["scores"]
         preds = torch.round(mean_scores).tolist()
         labels = self.multihot(test_row["labels"])
         hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
-        zero_one_acc = zero_one_loss_ceil(y_pred=preds, y_target=labels)
 
         return ClassificationResults(
             mean_scores=mean_scores.detach().tolist(),
+            clean_scores=clean_scores,
             hamming_accuracy=hamming_acc,
-            zero_one_accuracy=zero_one_acc,
             mean_predicted_entropy=torch.mean(
                 var_output["classification"]["predicted_entropy"]
             ).item(),
diff --git a/src/arc_spice/eval/translation_error.py b/src/arc_spice/eval/translation_error.py
index 510b157..e15a5f4 100644
--- a/src/arc_spice/eval/translation_error.py
+++ b/src/arc_spice/eval/translation_error.py
@@ -1,3 +1,4 @@
+import torch
 from comet import download_model, load_from_checkpoint
 from torcheval.metrics.functional import bleu_score
 
@@ -10,3 +11,7 @@ def get_comet_model(model_path="Unbabel/wmt22-comet-da"):
     # Load the model checkpoint:
     comet_model_pth = download_model(model=model_path)
     return load_from_checkpoint(comet_model_pth)
+
+
+def conditional_probability(prob_scores: torch.Tensor):
+    return torch.prod(torch.pow(prob_scores, 1 / len(prob_scores)), dim=-1)

From a22238b5982e4c04a604ce671a538c52bb788819 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Tue, 3 Dec 2024 15:45:52 +0000
Subject: [PATCH 09/14] updated gen_jobscripts script to create experiments in
 the slurm_scripts folder; now saves multiple runs under subdirectories named
 via the seed

---
 .gitignore                                    |  1 +
 .../experiment/full_experiment_zero_shot.yaml | 15 ++++++
 scripts/gen_jobscripts.py                     | 51 ++++++++++++++-----
 scripts/pipeline_inference.py                 | 15 +++---
 scripts/single_component_inference.py         | 16 +++---
 src/arc_spice/config/jobscript_template.sh    |  1 -
 6 files changed, 71 insertions(+), 28 deletions(-)
 create mode 100644 config/experiment/full_experiment_zero_shot.yaml

diff --git a/.gitignore b/.gitignore
index 9d6dc80..75a9bb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,6 +160,7 @@ Thumbs.db
 
 slurm_scripts/slurm_logs*
+slurm_scripts/experiments*
 
 # other temp
 .vscode
diff --git a/config/experiment/full_experiment_zero_shot.yaml b/config/experiment/full_experiment_zero_shot.yaml
new file mode 100644
index 0000000..db95e53
--- /dev/null
+++ b/config/experiment/full_experiment_zero_shot.yaml
@@ -0,0 +1,15 @@
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
diff --git a/scripts/gen_jobscripts.py b/scripts/gen_jobscripts.py
index f96a764..7420743 100644
--- a/scripts/gen_jobscripts.py
+++ b/scripts/gen_jobscripts.py
@@ -1,3 +1,4 @@
+import os
 from pathlib import Path
 
 from jinja2 import Environment, FileSystemLoader
@@ -25,30 +26,56 @@ def main(experiment_config_path: str):
     )
     pipeline_config = open_yaml_path(pipeline_conf_dir)
     # Get jinja template
-    print(PROJECT_DIR / "src" / "arc_spice" / "configs")
     environment = Environment(
         loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
     )
     template = environment.get_template("jobscript_template.sh")
-    for model in pipeline_config:
-        script_dict: dict = experiment_config["bask"]
-        seed = experiment_config["seed"][0]
-        script_dict.update(
+    # We don't want to overwrite results
+
+    for index, seed in enumerate(experiment_config["seed"]):
+        os.makedirs(
+            f"slurm_scripts/experiments/{experiment_name}/run_{index}", exist_ok=False
+        )
+        for model in pipeline_config:
+            model_script_dict: dict = experiment_config["bask"]
+            model_script_dict.update(
+                {
+                    "script_name": (
+                        "scripts/single_component_inference.py "
+                        f"{pipeline_conf_dir} {data_conf_dir} {seed}"
+                        f" {experiment_name} {model}"
+                    ),
+                    "job_name": f"{experiment_name}_{model}",
+                    "seed": seed,
+                }
+            )
+            model_train_script = template.render(model_script_dict)
+
+            with open(
+                f"slurm_scripts/experiments/{experiment_name}/run_{index}/{model}.sh",
+                "w",
+            ) as f:
+                f.write(model_train_script)
+
+        pipeline_script_dict: dict = experiment_config["bask"]
+        pipeline_script_dict.update(
             {
                 "script_name": (
-                    "scripts/single_component_inference.py "
+                    "scripts/pipeline_inference.py "
                     f"{pipeline_conf_dir} {data_conf_dir} {seed}"
-                    f" {experiment_name} {model}"
+                    f" {experiment_name}"
                 ),
-                "array_number": 0,
-                "job_name": f"{experiment_name}_{model}",
+                "job_name": f"{experiment_name}_full_pipeline",
                 "seed": seed,
             }
         )
-        train_script = template.render(script_dict)
+        pipeline_train_script = template.render(pipeline_script_dict)
 
-        with open(f"slurm_scripts/{model}_test.sh", "w") as f:
-            f.write(train_script)
+        with open(
+            f"slurm_scripts/experiments/{experiment_name}/run_{index}/full_pipeline.sh",
+            "w",
+        ) as f:
+            f.write(pipeline_train_script)
 
 
 if __name__ == "__main__":
diff --git a/scripts/pipeline_inference.py b/scripts/pipeline_inference.py
index 0700c6e..e001b78 100644
--- a/scripts/pipeline_inference.py
+++ b/scripts/pipeline_inference.py
@@ -25,6 +25,14 @@ def main(
         seed: seed for the inference pass
         experiment_name: name of experiment for saving purposes
     """
+    # create save directory -> fail if already exists
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    os.makedirs(save_loc, exist_ok=False)
     # seed experiment
     seed_everything(seed=seed)
     # initialise pipeline
@@ -43,13 +51,6 @@ def main(
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}"
-    )
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
         json.dump(test_results, save_file)
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index 8446b0d..ee14fe6 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -45,6 +45,14 @@ def main(
         experiment_name: name of experiment for saving purposes
         model_key: name of model on which to run inference
     """
+    # create save directory -> fail if already exists
+    data_name = data_config_pth.split("/")[-1].split(".")[0]
+    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
+    save_loc = (
+        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
+        f"{experiment_name}/seed_{seed}/"
+    )
+    os.makedirs(save_loc, exist_ok=False)
     # seed experiment
     seed_everything(seed=seed)
     # initialise pipeline
@@ -79,14 +87,6 @@ def main(
         results_getter=results_getter,
     )
 
-    data_name = data_config_pth.split("/")[-1].split(".")[0]
-    pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
-    save_loc = (
-        f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/{experiment_name}/"
-        f"single_component"
-    )
-    os.makedirs(save_loc, exist_ok=True)
-
     with open(f"{save_loc}/{model_key}.json", "w") as save_file:
         json.dump(test_results, save_file)
diff --git a/src/arc_spice/config/jobscript_template.sh b/src/arc_spice/config/jobscript_template.sh
index 6583d2b..9df076e 100644
--- a/src/arc_spice/config/jobscript_template.sh
+++ b/src/arc_spice/config/jobscript_template.sh
@@ -6,7 +6,6 @@
 #SBATCH --nodes {{ node_number }}
 #SBATCH --gpus {{ gpu_number }}
 #SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
-#SBATCH --array=0-{{ array_number }}
 #SBATCH --cpus-per-gpu 18
 
 

From 0b7e2245686c5af6376bc6a6e43aac08e98c40c2 Mon Sep 17 00:00:00 2001
From: tbc
Date: Tue, 3 Dec 2024 16:07:36 +0000
Subject: [PATCH 10/14] moving prototype slurm scripts to the experiments
 subdirectory

---
 slurm_scripts/classifier_test.sh       | 29 --------------------------
 slurm_scripts/classifier_test_short.sh | 29 --------------------------
 slurm_scripts/ocr_test.sh              | 29 --------------------------
 slurm_scripts/translator_test.sh       | 29 --------------------------
 4 files changed, 116 deletions(-)
 delete mode 100644 slurm_scripts/classifier_test.sh
 delete mode 100644 slurm_scripts/classifier_test_short.sh
 delete mode 100644 slurm_scripts/ocr_test.sh
 delete mode 100644 slurm_scripts/translator_test.sh

diff --git a/slurm_scripts/classifier_test.sh b/slurm_scripts/classifier_test.sh
deleted file mode 100644
index e8351b1..0000000
--- a/slurm_scripts/classifier_test.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-#SBATCH --account vjgo8416-spice
-#SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_classifier
-#SBATCH --time 0-12:0:0
-#SBATCH --nodes 1
-#SBATCH --gpus 1
-#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_classifier-%j.out
-#SBATCH --array=0-0
-#SBATCH --cpus-per-gpu 18
-
-
-# Load required modules here
-module purge
-module load baskerville
-module load bask-apps/live/live
-module load Python/3.10.8-GCCcore-12.2.0
-
-
-# change working directory
-cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
-
-source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
-
-# change huggingface cache to be in project dir rather than user home
-export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
-
-# TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test classifier
\ No newline at end of file
diff --git a/slurm_scripts/classifier_test_short.sh b/slurm_scripts/classifier_test_short.sh
deleted file mode 100644
index a59b8a7..0000000
--- a/slurm_scripts/classifier_test_short.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-#SBATCH --account vjgo8416-spice
-#SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_classifier
-#SBATCH --time 0-0:12:0
-#SBATCH --nodes 1
-#SBATCH --gpus 1
-#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_classifier-%j.out
-#SBATCH --array=0-0
-#SBATCH --cpus-per-gpu 18
-
-
-# Load required modules here
-module purge
-module load baskerville
-module load bask-apps/live/live
-module load Python/3.10.8-GCCcore-12.2.0
-
-
-# change working directory
-cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
-
-source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
-
-# change huggingface cache to be in project dir rather than user home
-export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
-
-# TODO: script uses relative path to project home so must be run from home, fix
-python single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test classifier
\ No newline at end of file
diff --git a/slurm_scripts/ocr_test.sh b/slurm_scripts/ocr_test.sh
deleted file mode 100644
index 0146785..0000000
--- a/slurm_scripts/ocr_test.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-#SBATCH --account vjgo8416-spice
-#SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_ocr
-#SBATCH --time 0-12:0:0
-#SBATCH --nodes 1
-#SBATCH --gpus 1
-#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_ocr-%j.out
-#SBATCH --array=0-0
-#SBATCH --cpus-per-gpu 18
-
-
-# Load required modules here
-module purge
-module load baskerville
-module load bask-apps/live/live
-module load Python/3.10.8-GCCcore-12.2.0
-
-
-# change working directory
-cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
-
-source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
-
-# change huggingface cache to be in project dir rather than user home
-export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
-
-# TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test ocr
\ No newline at end of file
diff --git a/slurm_scripts/translator_test.sh b/slurm_scripts/translator_test.sh
deleted file mode 100644
index 506603a..0000000
--- a/slurm_scripts/translator_test.sh
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/bin/bash
-#SBATCH --account vjgo8416-spice
-#SBATCH --qos turing
-#SBATCH --job-name baskerville_pipeline_inference_test_translator
-#SBATCH --time 0-24:0:0
-#SBATCH --nodes 1
-#SBATCH --gpus 1
-#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/baskerville_pipeline_inference_test_translator-%j.out
-#SBATCH --array=0-0
-#SBATCH --cpus-per-gpu 18
-
-
-# Load required modules here
-module purge
-module load baskerville
-module load bask-apps/live/live
-module load Python/3.10.8-GCCcore-12.2.0
-
-
-# change working directory
-cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
-
-source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
-
-# change huggingface cache to be in project dir rather than user home
-export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
-
-# TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/RTC_configs/roberta-mt5-zero-shot.yaml /bask/projects/v/vjgo8416-spice/ARC-SPICE/config/data_configs/l1_fr_to_en.yaml 42 baskerville_pipeline_inference_test translator
\ No newline at end of file

From c21e2d8d9ca09e9438ff44200d9a57110c924a8c Mon Sep 17 00:00:00 2001
From: tbc
Date: Tue, 3 Dec 2024 16:19:54 +0000
Subject: [PATCH 11/14] made a small change to the inference pipeline, which
 allows the experiment_result--seed combination directory to exist already;
 since this is caught in the jobscript creation I don't foresee this being an
 issue

---
 scripts/pipeline_inference.py         | 3 ++-
 scripts/single_component_inference.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/scripts/pipeline_inference.py b/scripts/pipeline_inference.py
index e001b78..834cdb6 100644
--- a/scripts/pipeline_inference.py
+++ b/scripts/pipeline_inference.py
@@ -32,7 +32,8 @@ def main(
         f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
         f"{experiment_name}/seed_{seed}/"
     )
-    os.makedirs(save_loc, exist_ok=False)
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
     # seed experiment
     seed_everything(seed=seed)
     # initialise pipeline
diff --git a/scripts/single_component_inference.py b/scripts/single_component_inference.py
index ee14fe6..fb98747 100644
--- a/scripts/single_component_inference.py
+++ b/scripts/single_component_inference.py
@@ -52,7 +52,8 @@ def main(
         f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
         f"{experiment_name}/seed_{seed}/"
     )
-    os.makedirs(save_loc, exist_ok=False)
+    # This directory needs to exist for all 4 experiments
+    os.makedirs(save_loc, exist_ok=True)
     # seed experiment
     seed_everything(seed=seed)
     # initialise pipeline

From 8106e1cb2461025a9812d2a2f54f954e90f2b4a9 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Tue, 3 Dec 2024 16:21:33 +0000
Subject: [PATCH 12/14] removed another test bash script

---
 slurm_scripts/single_component_inference.sh | 27 ---------------------
 1 file changed, 27 deletions(-)
 delete mode 100644 slurm_scripts/single_component_inference.sh

diff --git a/slurm_scripts/single_component_inference.sh b/slurm_scripts/single_component_inference.sh
deleted file mode 100644
index a76f23d..0000000
--- a/slurm_scripts/single_component_inference.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-#SBATCH --account vjgo8416-spice
-#SBATCH --qos turing
-#SBATCH --job-name SPICE_variational_RTC
-#SBATCH --time 0-12:0:0
-#SBATCH --nodes 1
-#SBATCH --gpus 1
-#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/%j.out
-#SBATCH --cpus-per-gpu 18
-
-# Load required modules here
-module purge
-module load baskerville
-module load bask-apps/live/live
-module load Python/3.10.8-GCCcore-12.2.0
-
-
-# change working directory
-cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/
-
-source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate
-
-# change huggingface cache to be in project dir rather than user home
-export HF_HOME="/bask/projects/v/vjgo8416-spice/hf_cache"
-
-# TODO: script uses relative path to project home so must be run from home, fix
-python scripts/single_component_inference.py

From 7ac88b8e2c4417e0afed9b8a2f43d92f0a112394 Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Tue, 3 Dec 2024 18:08:03 +0000
Subject: [PATCH 13/14] updated README.md in the /scripts folder

---
 scripts/README.md | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/scripts/README.md b/scripts/README.md
index 2992405..d5e8e74 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -106,3 +106,30 @@ It's called like so e.g. from project root:
 ```bash
 python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path] translator
 ```
+
+## gen_jobscripts.py
+
+Creates jobscript `.sh` files for an experiment, which in this case refers to a `data_config` and `pipeline_config` combination.
+It takes a single argument, `experiment_config_path`, which is the path to a `.yaml` file structured as below:
+
+### Example experiment config:
+
+```yaml
+data_config: l1_fr_to_en
+
+pipeline_config: roberta-mt5-zero-shot
+
+seed:
+  - 42
+  - 43
+  - 44
+
+bask:
+  jobname: "full_experiment_with_zero_shot"
+  walltime: '0-24:0:0'
+  gpu_number: 1
+  node_number: 1
+  hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
+
+
+```

From 92a98f06617d5b5471e5b700ac694c26f68e63ef Mon Sep 17 00:00:00 2001
From: J-Dymond
Date: Fri, 6 Dec 2024 11:36:53 +0000
Subject: [PATCH 14/14] addressed comments from pull request; also added
 additional outputs to translation to allow additional confidence measures

---
 src/arc_spice/eval/classification_error.py    |  7 --
 src/arc_spice/eval/inference_utils.py         | 26 +++---
 .../RTC_single_component_pipeline.py          | 21 ++---
 .../RTC_variational_pipeline.py               | 57 +------------
 src/arc_spice/variational_pipelines/utils.py  | 74 ++++++++++++++++-
 tests/test_inference.py                       | 31 -------
 6 files changed, 98 insertions(+), 118 deletions(-)

diff --git a/src/arc_spice/eval/classification_error.py b/src/arc_spice/eval/classification_error.py
index 71292d3..2e8f36c 100644
--- a/src/arc_spice/eval/classification_error.py
+++ b/src/arc_spice/eval/classification_error.py
@@ -1,7 +1,4 @@
-import math
-
 import torch
-from sklearn.metrics import zero_one_loss
 
 
 def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
@@ -11,10 +8,6 @@ def aggregate_score(probs: torch.Tensor) -> torch.Tensor:
     return 1 - torch.mean(distance)
 
 
-def zero_one_loss_ceil(y_target, y_pred):
-    return math.ceil(zero_one_loss(y_target, y_pred, normalize=True))
-
-
 def MC_dropout_scores(
     variational_probs: list[float], epsilon: float = 1e-14
 ) -> dict[str, torch.Tensor]:
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index 27513cf..bc8c774 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -17,15 +17,7 @@
 )
 
 RecognitionResults = namedtuple("RecognitionResults", ["confidence", "accuracy"])
-ClassificationResults = namedtuple(
-    "ClassificationResults",
-    [
-        "clean_scores",
-        "mean_scores",
-        "hamming_accuracy",
-        "mean_predicted_entropy",
-    ],
-)
+
 TranslationResults = namedtuple(
     "TranslationResults",
     [
         "full_output",
         "clean_conditional_probability",
         "comet_score",
"weighted_semantic_density", + "mean_entropy", + "sequence_lengths", + ], +) + +ClassificationResults = namedtuple( + "ClassificationResults", + [ + "clean_scores", + "mean_scores", + "hamming_accuracy", + "mean_predicted_entropy", ], ) @@ -79,6 +83,8 @@ def translation_results( source_text = test_row["target_text"] target_text = test_row["target_text"] clean_translation = clean_output["translation"]["full_output"] + clean_entropy: torch.Tensor = clean_output["translation"]["mean_entropy"] + seq_lens: torch.Tensor = var_output["translation"]["sequence_length"] probs: list[torch.Tensor] = clean_output["translation"]["probs"] clean_cond_prob = [ conditional_probability(prob.squeeze()).detach().tolist() for prob in probs @@ -102,6 +108,8 @@ def translation_results( comet_score=comet_output["scores"][0], full_output=clean_translation, clean_conditional_probability=clean_cond_prob, + mean_entropy=clean_entropy, + sequence_lengths=seq_lens, weighted_semantic_density=var_output["translation"][ "weighted_semantic_density" ], @@ -144,4 +152,5 @@ def run_inference( test_row=inp, ) results.append({inp["celex_id"]: row_results_dict}) + break return results diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py index 0be0025..a0788d1 100644 --- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py @@ -4,10 +4,14 @@ from transformers import pipeline from arc_spice.variational_pipelines.RTC_variational_pipeline import ( - CustomTranslationPipeline, RTCVariationalPipelineBase, ) -from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout +from arc_spice.variational_pipelines.utils import ( + CustomTranslationPipeline, + dropout_off, + dropout_on, + set_dropout, +) class RTCSingleComponentPipeline(RTCVariationalPipelineBase): @@ -34,19 +38,6 @@ def __init__( # define objects that are needed and nothing else # naive outputs can remain the same, though only the appropriate outputs will # be outputted - self.naive_outputs = { - "recognition": [ - "outputs", - ], - "translation": [ - "full_output", - "outputs", - "probs", - ], - "classification": [ - "scores", - ], - } self.step_name = step_name self.input_key = input_key self.forward_function = forward_function diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py index 15bf990..2d9f1ea 100644 --- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py +++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py @@ -1,11 +1,10 @@ -import copy from typing import Any import torch -from torch.nn.functional import softmax -from transformers import TranslationPipeline, pipeline +from transformers import pipeline from arc_spice.variational_pipelines.utils import ( + CustomTranslationPipeline, RTCVariationalPipelineBase, dropout_off, dropout_on, @@ -134,55 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]: # on standard call return the clean output def __call__(self, x): return self.clean_inference(x) - - -# Translation pipeline with additional functionality to save logits from fwd pass -class CustomTranslationPipeline(TranslationPipeline): - """ - custom translation pipeline to return the logits with the generated text. 
diff --git a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
index 0be0025..a0788d1 100644
--- a/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_single_component_pipeline.py
@@ -4,10 +4,14 @@
 from transformers import pipeline

 from arc_spice.variational_pipelines.RTC_variational_pipeline import (
-    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
 )
-from arc_spice.variational_pipelines.utils import dropout_off, dropout_on, set_dropout
+from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
+    dropout_off,
+    dropout_on,
+    set_dropout,
+)


 class RTCSingleComponentPipeline(RTCVariationalPipelineBase):
@@ -34,19 +38,6 @@ def __init__(
         # define objects that are needed and nothing else
         # naive outputs can remain the same, though only the appropriate outputs will
         # be outputted
-        self.naive_outputs = {
-            "recognition": [
-                "outputs",
-            ],
-            "translation": [
-                "full_output",
-                "outputs",
-                "probs",
-            ],
-            "classification": [
-                "scores",
-            ],
-        }
         self.step_name = step_name
         self.input_key = input_key
         self.forward_function = forward_function
diff --git a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
index 15bf990..2d9f1ea 100644
--- a/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
+++ b/src/arc_spice/variational_pipelines/RTC_variational_pipeline.py
@@ -1,11 +1,10 @@
-import copy
 from typing import Any

 import torch
-from torch.nn.functional import softmax
-from transformers import TranslationPipeline, pipeline
+from transformers import pipeline

 from arc_spice.variational_pipelines.utils import (
+    CustomTranslationPipeline,
     RTCVariationalPipelineBase,
     dropout_off,
     dropout_on,
@@ -134,55 +133,3 @@ def variational_inference(self, x: torch.Tensor) -> tuple[dict, dict]:
     # on standard call return the clean output
     def __call__(self, x):
         return self.clean_inference(x)
-
-
-# Translation pipeline with additional functionality to save logits from fwd pass
-class CustomTranslationPipeline(TranslationPipeline):
-    """
-    custom translation pipeline to return the logits with the generated text. Largely
-    the same as the pytorch version with some additional arguments passed to the
-    `generate` method.
-    """
-
-    def postprocess(
-        self,
-        model_outputs: dict,
-        **postprocess_params,
-    ):
-        # model_outputs gets overwritten in the super().postprocess call
-        # make a copy here so we retain the information we want
-        raw_out = copy.deepcopy(model_outputs)
-        processed = super().postprocess(model_outputs, **postprocess_params)
-
-        return {
-            "translation_text": processed[0]["translation_text"],
-            "raw_outputs": raw_out,
-        }
-
-    def _forward(self, model_inputs, **generate_kwargs):
-        if self.framework == "pt":
-            in_b, input_length = model_inputs["input_ids"].shape
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        self.check_inputs(
-            input_length,
-            generate_kwargs.get("min_length", self.model.config.min_length),
-            generate_kwargs.get("max_length", self.model.config.max_length),
-        )
-        out = self.model.generate(**model_inputs, **generate_kwargs)
-        output_ids = out["sequences"]
-        out_b = output_ids.shape[0]
-        if self.framework == "pt":
-            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
-        elif self.framework == "tf":
-            raise NotImplementedError
-
-        # logits are a tuple of length output_ids[-1]-1
-        # each element is a tensor of shape (batch_size, vocab_size)
-        logits = torch.stack(out["logits"], dim=1)
-        # get softmax of the logits to get token probabilities
-        softmax_logits = softmax(logits, dim=-1)
-        max_token_scores = torch.max(softmax_logits, dim=-1).values
-
-        return {"output_ids": output_ids, "scores": max_token_scores}
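The class removed above is re-added to `utils.py` in the next diff, so downstream code now imports it from there:

```python
# New import location after this patch; importing CustomTranslationPipeline
# from RTC_variational_pipeline would now fail.
from arc_spice.variational_pipelines.utils import CustomTranslationPipeline
```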
+ """ + + def postprocess( + self, + model_outputs: dict, + **postprocess_params, + ): + # model_outputs gets overwritten in the super().postprocess call + # make a copy here so we retain the information we want + raw_out = copy.deepcopy(model_outputs) + processed = super().postprocess(model_outputs, **postprocess_params) + + return { + "translation_text": processed[0]["translation_text"], + "raw_outputs": raw_out, + } + + def _forward(self, model_inputs, **generate_kwargs): + if self.framework == "pt": + in_b, input_length = model_inputs["input_ids"].shape + elif self.framework == "tf": + raise NotImplementedError + + self.check_inputs( + input_length, + generate_kwargs.get("min_length", self.model.config.min_length), + generate_kwargs.get("max_length", self.model.config.max_length), + ) + out = self.model.generate(**model_inputs, **generate_kwargs) + output_ids = out["sequences"] + out_b = output_ids.shape[0] + if self.framework == "pt": + output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:]) + elif self.framework == "tf": + raise NotImplementedError + + # logits are a tuple of length output_ids[-1]-1 + # each element is a tensor of shape (batch_size, vocab_size) + logits = torch.stack(out["logits"], dim=1) + # get softmax of the logits to get token probabilities + softmax_logits = softmax(logits, dim=-1) + vocab_size = softmax_logits.shape[-1] + normalised_entropy = torch.distributions.Categorical( + probs=softmax_logits + ).entropy() / math.log(vocab_size) + max_token_scores = torch.max(softmax_logits, dim=-1).values + + return { + "output_ids": output_ids, + "scores": max_token_scores, + "entropy": normalised_entropy, + } diff --git a/tests/test_inference.py b/tests/test_inference.py index 19fa240..a7e6b08 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -2,9 +2,7 @@ from unittest.mock import MagicMock, patch import pytest -from sklearn.metrics import hamming_loss -from arc_spice.eval.classification_error import zero_one_loss_ceil from arc_spice.utils import open_yaml_path from arc_spice.variational_pipelines.RTC_single_component_pipeline import ( ClassificationVariationalPipeline, @@ -42,35 +40,6 @@ def dummy_metadata(): } -def test_errors(): - dummy_target = [0, 1, 0, 1, 0] - dummy_middle_output = [1, 1, 0, 1, 0] - - assert hamming_loss(dummy_target, dummy_middle_output) == pytest.approx( - 0.2, abs=1e-5 - ) - assert zero_one_loss_ceil(dummy_target, dummy_middle_output) == pytest.approx( - 1.0, abs=1e-5 - ) - - dummy_correct_output = [0, 1, 0, 1, 0] - - assert hamming_loss(dummy_target, dummy_correct_output) == pytest.approx( - 0.0, abs=1e-5 - ) - assert zero_one_loss_ceil(dummy_target, dummy_correct_output) == pytest.approx( - 0.0, abs=1e-5 - ) - - dummy_incorrect_output = [1, 0, 1, 0, 1] - assert hamming_loss(dummy_target, dummy_incorrect_output) == pytest.approx( - 1.0, abs=1e-5 - ) - assert zero_one_loss_ceil(dummy_target, dummy_incorrect_output) == pytest.approx( - 1.0, abs=1e-5 - ) - - def test_pipeline_inputs(dummy_data, dummy_metadata): pipeline_config = open_yaml_path(PIPELINE_PATH)