diff --git a/src/arc_spice/eval/inference_utils.py b/src/arc_spice/eval/inference_utils.py index 871d890..b8ea916 100644 --- a/src/arc_spice/eval/inference_utils.py +++ b/src/arc_spice/eval/inference_utils.py @@ -30,27 +30,15 @@ def get_results( clean_output: dict[str, dict], var_output: dict[str, dict], test_row: dict[str, list[int]], - results_dict: dict[str, dict[str, list[float]]], + results_dict: dict[str, dict[str, list]], ): for step_name in clean_output: results_dict = self.func_map[step_name]( test_row, clean_output, var_output, results_dict ) - + results_dict["input_data"]["celex_ids"].append(test_row["celex_id"]) return results_dict - def single_step_results( - self, - clean_output: dict[str, dict], - var_output: dict[str, dict], - test_row: dict[str, list[int]], - results_dict: dict[str, dict[str, list[float]]], - step_name: str, - ): - return self.func_map[step_name]( - test_row, clean_output, var_output, results_dict - ) - def recognition_results(self, test_row, clean_output, var_output, results_dict): assert test_row is not None assert clean_output is not None @@ -79,6 +67,7 @@ def translation_results(self, test_row, clean_output, var_output, results_dict): comet_inp, batch_size=8, accelerator=comet_device ) comet_scores = comet_output["scores"] + results_dict["translation"]["full_output"].append(clean_translation) results_dict["translation"]["comet_score"].append(comet_scores[0]) results_dict["translation"]["weighted_semantic_density"].append( var_output["translation"]["weighted_semantic_density"] @@ -92,6 +81,9 @@ def classification_results(self, test_row, _, var_output, results_dict): labels = self.multihot(test_row["labels"]) hamming_acc = hamming_loss(y_pred=preds, y_true=labels) zero_one_acc = zero_one_loss(y_pred=preds, y_true=labels) + results_dict["classification"]["mean_scores"].append( + mean_scores.detach().tolist() + ) results_dict["classification"]["hamming_accuracy"].append(hamming_acc) results_dict["classification"]["zero_one_accuracy"].append(zero_one_acc) results_dict["classification"]["mean_predicted_entropy"].append( @@ -107,17 +99,27 @@ def run_inference( results_getter: ResultsGetter, ): results_dict = { + "input_data": {"celex_ids": []}, # Placeholder - "ocr": {"confidence": [], "accuracy": []}, - "translation": {"weighted_semantic_density": [], "comet_score": []}, + "ocr": {"outputs": [], "confidence": [], "accuracy": []}, # PLACEHOLDER + "translation": { + "full_output": [], + "weighted_semantic_density": [], + "comet_score": [], + }, "classification": { + "mean_scores": [], "mean_predicted_entropy": [], "hamming_accuracy": [], "zero_one_accuracy": [], }, } if isinstance(pipeline, RTCSingleComponentPipeline): - results_dict = {pipeline.step_name: results_dict[pipeline.step_name]} + # only need appropriate result dict when evaluating individual component + results_dict = { + "input_data": {"celex_ids": []}, + pipeline.step_name: results_dict[pipeline.step_name], + } for _, inp in enumerate(tqdm(dataloader)): clean_out, var_out = pipeline.variational_inference(inp)