
Commit

Merge pull request #30 from alan-turing-institute/15-run-inference-on-baskerville

15 run inference on baskerville
J-Dymond authored Dec 6, 2024
2 parents 2f15ada + 92a98f0 commit 0b79ed9
Showing 16 changed files with 347 additions and 116 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -160,6 +160,7 @@ Thumbs.db


slurm_scripts/slurm_logs*
slurm_scripts/experiments*
# other
temp
.vscode
2 changes: 1 addition & 1 deletion config/RTC_configs/roberta-mt5-zero-shot.yaml
@@ -1,4 +1,4 @@
OCR:
ocr:
specific_task: "image-to-text"
model: "microsoft/trocr-base-handwritten"

2 changes: 2 additions & 0 deletions config/data_configs/l1_fr_to_en.yaml
@@ -5,3 +5,5 @@ level: 1
lang_pair:
source: "fr"
target: "en"

drop_length: 1000
13 changes: 13 additions & 0 deletions config/experiment/baskerville_pipeline_inference_test.yaml
@@ -0,0 +1,13 @@
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42

bask:
jobname: "shortened_input_test"
walltime: '0-12:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
15 changes: 15 additions & 0 deletions config/experiment/full_experiment_zero_shot.yaml
@@ -0,0 +1,15 @@
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42
- 43
- 44

bask:
jobname: "full_experiment_with_zero_shot"
walltime: '0-24:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
27 changes: 27 additions & 0 deletions scripts/README.md
@@ -106,3 +106,30 @@ It's called like so e.g. from project root:
```bash
python scripts/pipeline_inference.py [pipeline_config_path] [data_config_path] translator
```

## gen_jobscripts.py

Creates jobscript `.sh` files for an experiment, where an experiment means a `data_config` and `pipeline_config` combination.
It takes a single argument, `experiment_config_path`: the path to a `.yaml` file structured as below (an example invocation follows the config):

### e.g. Experiment config:

```yaml
data_config: l1_fr_to_en

pipeline_config: roberta-mt5-zero-shot

seed:
- 42
- 43
- 44

bask:
jobname: "full_experiment_with_zero_shot"
walltime: '0-24:0:0'
gpu_number: 1
node_number: 1
hf_cache_dir: "/bask/projects/v/vjgo8416-spice/hf_cache"
```
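
Following the calling convention of the other scripts, it's called like so from project root:

```bash
python scripts/gen_jobscripts.py [experiment_config_path]
```

For each seed this writes one jobscript per pipeline component, plus one for the full pipeline, to `slurm_scripts/experiments/[experiment_name]/run_[index]/`, where `[experiment_name]` is the config filename stem.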
82 changes: 82 additions & 0 deletions scripts/gen_jobscripts.py
@@ -0,0 +1,82 @@
import os
from pathlib import Path

from jinja2 import Environment, FileSystemLoader
from jsonargparse import CLI

from arc_spice.utils import open_yaml_path

PROJECT_DIR = Path(__file__, "..", "..").resolve()


def main(experiment_config_path: str):
    """
    Generate jobscript files for each seed and pipeline component of an experiment
    Args:
        experiment_config_path: path to experiment config yaml file
    """
experiment_name = experiment_config_path.split("/")[-1].split(".")[0]
experiment_config = open_yaml_path(experiment_config_path)
pipeline_conf_dir = (
f"{PROJECT_DIR}/config/RTC_configs/{experiment_config['pipeline_config']}.yaml"
)
data_conf_dir = (
f"{PROJECT_DIR}/config/data_configs/{experiment_config['data_config']}.yaml"
)
pipeline_config = open_yaml_path(pipeline_conf_dir)
# Get jinja template
environment = Environment(
loader=FileSystemLoader(PROJECT_DIR / "src" / "arc_spice" / "config")
)
template = environment.get_template("jobscript_template.sh")
# We don't want to overwrite results

for index, seed in enumerate(experiment_config["seed"]):
os.makedirs(
f"slurm_scripts/experiments/{experiment_name}/run_{index}", exist_ok=False
)
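        # render one jobscript per pipeline component (top-level key of the pipeline config)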
for model in pipeline_config:
model_script_dict: dict = experiment_config["bask"]
model_script_dict.update(
{
"script_name": (
"scripts/single_component_inference.py "
f"{pipeline_conf_dir} {data_conf_dir} {seed}"
f" {experiment_name} {model}"
),
"job_name": f"{experiment_name}_{model}",
"seed": seed,
}
)
model_train_script = template.render(model_script_dict)

with open(
f"slurm_scripts/experiments/{experiment_name}/run_{index}/{model}.sh",
"w",
) as f:
f.write(model_train_script)

pipeline_script_dict: dict = experiment_config["bask"]
pipeline_script_dict.update(
{
"script_name": (
"scripts/pipeline_inference.py "
f"{pipeline_conf_dir} {data_conf_dir} {seed}"
f" {experiment_name}"
),
"job_name": f"{experiment_name}_full_pipeline",
"seed": seed,
}
)
pipeline_train_script = template.render(pipeline_script_dict)

with open(
f"slurm_scripts/experiments/{experiment_name}/run_{index}/full_pipeline.sh",
"w",
) as f:
f.write(pipeline_train_script)


if __name__ == "__main__":
CLI(main)
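
`open_yaml_path` is imported from `arc_spice.utils` but not shown in this diff; a minimal sketch of its presumed behaviour:

```python
import yaml


def open_yaml_path(path: str) -> dict:
    # presumed behaviour: parse the yaml file at `path` into a dict
    with open(path) as f:
        return yaml.safe_load(f)
```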
28 changes: 19 additions & 9 deletions scripts/pipeline_inference.py
@@ -3,28 +3,43 @@

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.utils import open_yaml_path, seed_everything
from arc_spice.variational_pipelines.RTC_variational_pipeline import (
RTCVariationalPipeline,
)

OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str):
def main(
pipeline_config_pth: str, data_config_pth: str, seed: int, experiment_name: str
):
"""
Run inference on a given pipeline with provided data config
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
        seed: seed for the inference pass
experiment_name: name of experiment for saving purposes
"""
    # create save directory
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"{experiment_name}/seed_{seed}/"
)
# This directory needs to exist for all 4 experiments
os.makedirs(save_loc, exist_ok=True)
# seed experiment
seed_everything(seed=seed)
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
test_loader = data_sets["test"]
rtc_variational_pipeline = RTCVariationalPipeline(
model_pars=pipeline_config, data_pars=meta_data
@@ -37,11 +52,6 @@ def main(pipeline_config_pth: str, data_config_pth: str):
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}"
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/full_pipeline.json", "w") as save_file:
json.dump(test_results, save_file)

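`seed_everything` is likewise imported from `arc_spice.utils` without appearing in this diff; a plausible sketch, assuming the usual RNGs of a torch-based pipeline need seeding:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    # assumption: seed every RNG the pipeline might draw from
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```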
35 changes: 23 additions & 12 deletions scripts/single_component_inference.py
@@ -16,9 +16,9 @@

from jsonargparse import CLI

from arc_spice.data.multieurlex_utils import load_multieurlex_for_translation
from arc_spice.data.multieurlex_utils import load_multieurlex_for_pipeline
from arc_spice.eval.inference_utils import ResultsGetter, run_inference
from arc_spice.utils import open_yaml_path
from arc_spice.utils import open_yaml_path, seed_everything
from arc_spice.variational_pipelines.RTC_single_component_pipeline import (
ClassificationVariationalPipeline,
RecognitionVariationalPipeline,
@@ -28,19 +28,38 @@
OUTPUT_DIR = "outputs"


def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
def main(
pipeline_config_pth: str,
data_config_pth: str,
seed: int,
experiment_name: str,
model_key: str,
):
"""
Run inference on a given pipeline component with provided data config and model key.
Args:
pipeline_config_pth: path to pipeline config yaml file
data_config_pth: path to data config yaml file
        seed: seed for the inference pass
experiment_name: name of experiment for saving purposes
model_key: name of model on which to run inference
"""
    # create save directory
data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"{experiment_name}/seed_{seed}/"
)
# This directory needs to exist for all 4 experiments
os.makedirs(save_loc, exist_ok=True)
# seed experiment
seed_everything(seed=seed)
# initialise pipeline
data_config = open_yaml_path(data_config_pth)
pipeline_config = open_yaml_path(pipeline_config_pth)
data_sets, meta_data = load_multieurlex_for_translation(**data_config)
data_sets, meta_data = load_multieurlex_for_pipeline(**data_config)
test_loader = data_sets["test"]
if model_key == "ocr":
rtc_single_component_pipeline = RecognitionVariationalPipeline(
@@ -69,14 +88,6 @@ def main(pipeline_config_pth: str, data_config_pth: str, model_key: str):
results_getter=results_getter,
)

data_name = data_config_pth.split("/")[-1].split(".")[0]
pipeline_name = pipeline_config_pth.split("/")[-1].split(".")[0]
save_loc = (
f"{OUTPUT_DIR}/inference_results/{data_name}/{pipeline_name}/"
f"single_component"
)
os.makedirs(save_loc, exist_ok=True)

with open(f"{save_loc}/{model_key}.json", "w") as save_file:
json.dump(test_results, save_file)

28 changes: 28 additions & 0 deletions src/arc_spice/config/jobscript_template.sh
@@ -0,0 +1,28 @@
#!/bin/bash
#SBATCH --account vjgo8416-spice
#SBATCH --qos turing
#SBATCH --job-name {{ job_name }}
#SBATCH --time {{ walltime }}
#SBATCH --nodes {{ node_number }}
#SBATCH --gpus {{ gpu_number }}
#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/{{ job_name }}-%j.out
#SBATCH --cpus-per-gpu 18


# Load required modules here
module purge
module load baskerville
module load bask-apps/live/live
module load Python/3.10.8-GCCcore-12.2.0


# change working directory
cd /bask/projects/v/vjgo8416-spice/ARC-SPICE/

source /bask/projects/v/vjgo8416-spice/ARC-SPICE/env/bin/activate

# change huggingface cache to be in project dir rather than user home
export HF_HOME="{{ hf_cache_dir }}"

# TODO: script uses relative path to project home so must be run from home, fix
python {{ script_name }}
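
For illustration, rendering this template with the `bask` block from `full_experiment_zero_shot.yaml` for the full-pipeline job would give a header along these lines (the template reads `job_name`, which `gen_jobscripts.py` derives from the experiment config's filename; the `jobname` field inside the config is never referenced by the template):

```bash
#!/bin/bash
#SBATCH --account vjgo8416-spice
#SBATCH --qos turing
#SBATCH --job-name full_experiment_zero_shot_full_pipeline
#SBATCH --time 0-24:0:0
#SBATCH --nodes 1
#SBATCH --gpus 1
#SBATCH --output /bask/projects/v/vjgo8416-spice/ARC-SPICE/slurm_scripts/slurm_logs/full_experiment_zero_shot_full_pipeline-%j.out
#SBATCH --cpus-per-gpu 18
```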
13 changes: 12 additions & 1 deletion src/arc_spice/data/multieurlex_utils.py
@@ -133,6 +133,7 @@ def load_multieurlex(
level: int,
languages: list[str],
drop_empty: bool = True,
drop_length: int | None = None,
split: str | None = None,
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
"""
@@ -188,6 +189,11 @@
lambda x: all(x is not None for x in x["text"].values())
)

if drop_length:
dataset_dict = dataset_dict.filter(
lambda x: len(x["text"][languages[0]]) <= drop_length
)

# return datasets and metadata
return dataset_dict, metadata

@@ -197,11 +203,16 @@ def load_multieurlex_for_pipeline(
level: int,
lang_pair: dict[str, str],
drop_empty: bool = True,
drop_length: int | None = None,
load_ocr_data: bool = False,
) -> tuple[datasets.DatasetDict, dict[str, Any]]:
langs = [lang_pair["source"], lang_pair["target"]]
dataset_dict, meta_data = load_multieurlex(
data_dir=data_dir, level=level, languages=langs, drop_empty=drop_empty
data_dir=data_dir,
level=level,
languages=langs,
drop_empty=drop_empty,
drop_length=drop_length,
)
# instantiate the preprocessor
preprocesser = TranslationPreProcesser(lang_pair)
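
A toy illustration of the new `drop_length` filter's semantics (invented records, not the real dataset): a row is kept only if its source-language text, `languages[0]`, is at most `drop_length` characters long.

```python
drop_length = 1000  # as set in config/data_configs/l1_fr_to_en.yaml
languages = ["fr", "en"]  # [source, target]

records = [
    {"text": {"fr": "a" * 800, "en": "b" * 5000}},  # kept: source text within limit
    {"text": {"fr": "a" * 1500, "en": "b" * 100}},  # dropped: source text too long
]

kept = [r for r in records if len(r["text"][languages[0]]) <= drop_length]
assert len(kept) == 1  # only the source-language length is checked
```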
