Skip to content

Commit

Permalink
Merge branch 'master' into add_isort_as_a_pre_commit_hook
Browse files Browse the repository at this point in the history
  • Loading branch information
RomiPolaczek authored Dec 9, 2024
2 parents 9d0f67e + 788b374 commit eea8dd5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 30 deletions.
3 changes: 1 addition & 2 deletions fuse/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
OpApplyPatterns,
OpFunc,
OpKeepKeypaths,
OpLambda,
OpRepeat,
OpSet,
)
from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.pipelines.pipeline_default import PipelineDefault
Expand Down
68 changes: 40 additions & 28 deletions fuse/eval/metrics/metrics_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ def get(self, ids: Optional[Sequence[Hashable]] = None) -> Tuple[Dict[str, Any]]

permutation = [original_ids_pos[sample_id] for sample_id in required_ids]

# create the permuted dictionary
data = {}
for name, values in self._collected_data.items():
data[name] = [values[i] for i in permutation]
Expand Down Expand Up @@ -688,19 +687,27 @@ def eval(

class GroupAnalysis(MetricWithCollectorBase):
"""
Evaluate a metric per group and compute basic statistics about the different per group results.
Evaluate a metric per group and compute basic statistics over the different per group results.
eval() method returns a dictionary of the following format:
{'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...}
"""

def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None:
def __init__(
self,
metric: MetricBase,
group: str,
compute_group_stats: bool = True,
**super_kwargs: Any,
) -> None:
"""
:param metric: metric to analyze
:param group: key to extract the group from
:compute_group_stats: wether to compute stats such as mean, std, median over the per group results
:param super_kwargs: additional arguments for super class (MetricWithCollectorBase) constructor
"""
super().__init__(group=group, **super_kwargs)
self._metric = metric
self._compute_group_stats = compute_group_stats

def collect(self, batch: Dict) -> None:
"See super class"
Expand All @@ -718,7 +725,9 @@ def reset(self) -> None:
return super().reset()

def eval(
self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None
self,
results: Dict[str, Any] = None,
ids: Optional[Sequence[Hashable]] = None,
) -> Dict[str, Any]:
"""
See super class
Expand All @@ -731,45 +740,48 @@ def eval(
raise Exception(
"Error: group analysis is supported only when a unique identifier is specified. Add key 'id' to your data"
)
ids = np.array(ids)
# don't convert to array, this converts the tuple ids to arrays
# ids = np.array(ids)

groups = np.array(data["group"])
unique_groups = set(groups)

group_analysis_results = {}
for group_value in unique_groups:
group_ids = ids[groups == group_value]

# group_ids = ids[groups == group_value]
# use List comprehension instead of boolean filtering to support advance id types such as tuples
group_ids = [ids[i] for i in range(len(ids)) if groups[i] == group_value]
group_analysis_results[str(group_value)] = self._metric.eval(
results, group_ids
)

# compute stats
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
if self._compute_group_stats:
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass

return group_analysis_results

Expand Down

0 comments on commit eea8dd5

Please sign in to comment.