diff --git a/config/RTC_configs/bert-mt5-bert.yaml b/config/RTC_configs/bert-mt5-zero-shot.yaml
similarity index 100%
rename from config/RTC_configs/bert-mt5-bert.yaml
rename to config/RTC_configs/bert-mt5-zero-shot.yaml
diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py
index b8ea916..49c53ac 100644
--- a/src/arc_spice/eval/inference_utils.py
+++ b/src/arc_spice/eval/inference_utils.py
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from typing import Any
 
 import torch
 from sklearn.metrics import hamming_loss, zero_one_loss
@@ -16,8 +17,8 @@ class ResultsGetter:
-    def __init__(self, n_classes):
-        self.func_map: dict[str, Callable] = {
+    def __init__(self, n_classes: int):
+        self.results_func_map: dict[str, Callable] = {
             "recognition": self.recognition_results,
             "translation": self.translation_results,
             "classification": self.classification_results,
@@ -33,27 +34,30 @@ def get_results(
         results_dict: dict[str, dict[str, list]],
     ):
         for step_name in clean_output:
-            results_dict = self.func_map[step_name](
+            results_dict = self.results_func_map[step_name](
                 test_row, clean_output, var_output, results_dict
            )
         results_dict["input_data"]["celex_ids"].append(test_row["celex_id"])
         return results_dict
 
-    def recognition_results(self, test_row, clean_output, var_output, results_dict):
-        assert test_row is not None
-        assert clean_output is not None
-        assert var_output is not None
+    def recognition_results(self):
         # ### RECOGNITION ###
-        # TODO: add this into results_getter
-        return results_dict
+        # TODO: add this into results_getter issue #14
+        raise NotImplementedError()
 
-    def translation_results(self, test_row, clean_output, var_output, results_dict):
+    def translation_results(
+        self,
+        test_row: dict[str, Any],
+        clean_output: dict[str, dict],
+        var_output: dict[str, dict],
+        results_dict: dict[str, dict],
+    ):
         # ### TRANSLATION ###
         source_text = test_row["target_text"]
         target_text = test_row["target_text"]
         clean_translation = clean_output["translation"]["full_output"]
-        # load error model
+        # define error model inputs
         comet_inp = [
             {
                 "src": source_text,
@@ -74,7 +78,13 @@ def translation_results(self, test_row, clean_output, var_output, results_dict):
         )
         return results_dict
 
-    def classification_results(self, test_row, _, var_output, results_dict):
+    def classification_results(
+        self,
+        test_row: dict[str, Any],
+        _: dict[str, dict],
+        var_output: dict[str, dict],
+        results_dict: dict[str, dict],
+    ):
         # ### CLASSIFICATION ###
         mean_scores = var_output["classification"]["mean_scores"]
         preds = torch.round(mean_scores).tolist()
@@ -98,6 +108,9 @@ def run_inference(
     pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
     results_getter: ResultsGetter,
 ):
+    # get_results updates the results_dict, so it needs to be initialised before
+    # being run. It is overwritten if an RTCSingleComponentPipeline is used, since
+    # some entries are not needed.
     results_dict = {
         "input_data": {"celex_ids": []},  # Placeholder
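
Note: the code below is a minimal, self-contained sketch of the dispatch pattern that the renamed results_func_map uses in ResultsGetter.get_results, and of why the results dict is initialised before the per-step functions run. The ToyResultsGetter class, its step functions, and the example data are illustrative stand-ins and assumptions, not the repository's actual implementation.

# Simplified, standalone sketch of the dispatch pattern after this change:
# step names are mapped to per-step results functions, and get_results threads
# a pre-initialised results dict through whichever steps appear in the output.
from collections.abc import Callable
from typing import Any


class ToyResultsGetter:  # hypothetical stand-in for ResultsGetter
    def __init__(self) -> None:
        self.results_func_map: dict[str, Callable] = {
            "translation": self.translation_results,
            "classification": self.classification_results,
        }

    def get_results(
        self,
        test_row: dict[str, Any],
        clean_output: dict[str, dict],
        results_dict: dict[str, dict],
    ) -> dict[str, dict]:
        # Dispatch on the step names present in the pipeline output.
        for step_name in clean_output:
            results_dict = self.results_func_map[step_name](
                test_row, clean_output, results_dict
            )
        results_dict["input_data"]["celex_ids"].append(test_row["celex_id"])
        return results_dict

    def translation_results(self, test_row, clean_output, results_dict):
        # Record the clean translation output for this row.
        results_dict["translation"]["outputs"].append(
            clean_output["translation"]["full_output"]
        )
        return results_dict

    def classification_results(self, test_row, clean_output, results_dict):
        # Record the predicted labels for this row.
        results_dict["classification"]["preds"].append(
            clean_output["classification"]["preds"]
        )
        return results_dict


# As in run_inference, the results dict must exist before get_results is called,
# because each step function only appends to entries that are already present.
results_dict = {
    "input_data": {"celex_ids": []},
    "translation": {"outputs": []},
    "classification": {"preds": []},
}
getter = ToyResultsGetter()
out = getter.get_results(
    test_row={"celex_id": "32019R0001"},  # illustrative ID, not real data
    clean_output={
        "translation": {"full_output": "translated text"},
        "classification": {"preds": [0, 1, 0]},
    },
    results_dict=results_dict,
)
print(out["input_data"]["celex_ids"])  # ['32019R0001']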