Skip to content

Commit

Permalink
addressed some comments from pull request
Browse files Browse the repository at this point in the history
  • Loading branch information
J-Dymond committed Nov 27, 2024
1 parent 3db7f75 commit 7829244
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
File renamed without changes.
37 changes: 25 additions & 12 deletions src/arc_spice/eval/inference_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Callable
from typing import Any

import torch
from sklearn.metrics import hamming_loss, zero_one_loss
Expand All @@ -16,8 +17,8 @@


class ResultsGetter:
def __init__(self, n_classes):
self.func_map: dict[str, Callable] = {
def __init__(self, n_classes: int):
self.results_func_map: dict[str, Callable] = {
"recognition": self.recognition_results,
"translation": self.translation_results,
"classification": self.classification_results,
Expand All @@ -33,27 +34,30 @@ def get_results(
results_dict: dict[str, dict[str, list]],
):
for step_name in clean_output:
results_dict = self.func_map[step_name](
results_dict = self.results_func_map[step_name](
test_row, clean_output, var_output, results_dict
)
results_dict["input_data"]["celex_ids"].append(test_row["celex_id"])
return results_dict

def recognition_results(self, test_row, clean_output, var_output, results_dict):
assert test_row is not None
assert clean_output is not None
assert var_output is not None
def recognition_results(self):
# ### RECOGNITION ###
# TODO: add this into results_getter
return results_dict
# TODO: add this into results_getter issue #14
raise NotImplementedError()

def translation_results(self, test_row, clean_output, var_output, results_dict):
def translation_results(
self,
test_row: dict[str, Any],
clean_output: dict[str, dict],
var_output: dict[str, dict],
results_dict: dict[str, dict],
):
# ### TRANSLATION ###
source_text = test_row["target_text"]
target_text = test_row["target_text"]
clean_translation = clean_output["translation"]["full_output"]

# load error model
# define error model inputs
comet_inp = [
{
"src": source_text,
Expand All @@ -74,7 +78,13 @@ def translation_results(self, test_row, clean_output, var_output, results_dict):
)
return results_dict

def classification_results(self, test_row, _, var_output, results_dict):
def classification_results(
self,
test_row: dict[str, Any],
_: dict[str, dict],
var_output: dict[str, dict],
results_dict: dict[str, dict],
):
# ### CLASSIFICATION ###
mean_scores = var_output["classification"]["mean_scores"]
preds = torch.round(mean_scores).tolist()
Expand All @@ -98,6 +108,9 @@ def run_inference(
pipeline: RTCVariationalPipeline | RTCSingleComponentPipeline,
results_getter: ResultsGetter,
):
# Get_results updates the results_dict. So it needs to be initialised before being
# run. It is overwritten if a RTCSingleComponentPipeline is used, since some entries
# are not needed.
results_dict = {
"input_data": {"celex_ids": []},
# Placeholder
Expand Down

0 comments on commit 7829244

Please sign in to comment.