diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index bad41dca..8939215f 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -1508,7 +1508,7 @@ def process_results(self, doc, results, full_docs=None): gold = type(result)(gold) for metric in self._metric_fn_list.keys(): - if self.multiple_target: + if self.multiple_target and metric != "anls": # in the case where we have multiple targets, # return true if any are true # TODO: this may break for multipLe_target, non zero-or-1 metrics @@ -1535,9 +1535,11 @@ def process_results(self, doc, results, full_docs=None): else: result_score = 0.0 else: + if not isinstance(gold, list): + gold = [gold] try: result_score = self._metric_fn_list[metric]( - references=[gold], + references=gold, predictions=[result], **self._metric_fn_kwargs[metric], )