Skip to content

Commit

Permalink
updated inference outputs, now include input IDs as well to allow matching to original inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 26, 2024
1 parent 4539bdd commit 3db7f75
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,15 @@ def get_results(
clean_output: dict[str, dict],
var_output: dict[str, dict],
test_row: dict[str, list[int]],
results_dict: dict[str, dict[str, list[float]]],
results_dict: dict[str, dict[str, list]],
):
for step_name in clean_output:
results_dict = self.func_map[step_name](
test_row, clean_output, var_output, results_dict
)

results_dict["input_data"]["celex_ids"].append(test_row["celex_id"])
return results_dict

def single_step_results(
self,
clean_output: dict[str, dict],
var_output: dict[str, dict],
test_row: dict[str, list[int]],
results_dict: dict[str, dict[str, list[float]]],
step_name: str,
):
return self.func_map[step_name](
test_row, clean_output, var_output, results_dict
)

def recognition_results(self, test_row, clean_output, var_output, results_dict):
assert test_row is not None
assert clean_output is not None
Expand Down Expand Up @@ -79,6 +67,7 @@ def translation_results(self, test_row, clean_output, var_output, results_dict):
comet_inp, batch_size=8, accelerator=comet_device
)
comet_scores = comet_output["scores"]
results_dict["translation"]["full_output"].append(clean_translation)
results_dict["translation"]["comet_score"].append(comet_scores[0])
results_dict["translation"]["weighted_semantic_density"].append(
var_output["translation"]["weighted_semantic_density"]
Expand All @@ -92,6 +81,9 @@ def classification_results(self, test_row, _, var_output, results_dict):
labels = self.multihot(test_row["labels"])
hamming_acc = hamming_loss(y_pred=preds, y_true=labels)
zero_one_acc = zero_one_loss(y_pred=preds, y_true=labels)
results_dict["classification"]["mean_scores"].append(
mean_scores.detach().tolist()
)
results_dict["classification"]["hamming_accuracy"].append(hamming_acc)
results_dict["classification"]["zero_one_accuracy"].append(zero_one_acc)
results_dict["classification"]["mean_predicted_entropy"].append(
Expand All @@ -107,17 +99,27 @@ def run_inference(
results_getter: ResultsGetter,
):
results_dict = {
"input_data": {"celex_ids": []},
# Placeholder
"ocr": {"confidence": [], "accuracy": []},
"translation": {"weighted_semantic_density": [], "comet_score": []},
"ocr": {"outputs": [], "confidence": [], "accuracy": []}, # PLACEHOLDER
"translation": {
"full_output": [],
"weighted_semantic_density": [],
"comet_score": [],
},
"classification": {
"mean_scores": [],
"mean_predicted_entropy": [],
"hamming_accuracy": [],
"zero_one_accuracy": [],
},
}
if isinstance(pipeline, RTCSingleComponentPipeline):
results_dict = {pipeline.step_name: results_dict[pipeline.step_name]}
# only need appropriate result dict when evaluating individual component
results_dict = {
"input_data": {"celex_ids": []},
pipeline.step_name: results_dict[pipeline.step_name],
}

for _, inp in enumerate(tqdm(dataloader)):
clean_out, var_out = pipeline.variational_inference(inp)
Expand Down

0 comments on commit 3db7f75

Please sign in to comment.