Skip to content

Commit

Permalink
Sample id fix in GroupAnalysis (#384)
Browse files Browse the repository at this point in the history
* allow skipping stats in GroupAnalysis metric

* fix sample ids in collector

* support tuple ids in GroupAnalysis

* reverting previous fix

---------

Co-authored-by: Sivan Ravid <[email protected]>
  • Loading branch information
sivanravidos and Sivan Ravid authored Dec 9, 2024
1 parent a42be01 commit 788b374
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions fuse/eval/metrics/metrics_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,6 @@ def get(self, ids: Optional[Sequence[Hashable]] = None) -> Tuple[Dict[str, Any]]

permutation = [original_ids_pos[sample_id] for sample_id in required_ids]

# create the permuted dictionary
data = {}
for name, values in self._collected_data.items():
data[name] = [values[i] for i in permutation]
Expand Down Expand Up @@ -741,15 +740,17 @@ def eval(
raise Exception(
"Error: group analysis is supported only when a unique identifier is specified. Add key 'id' to your data"
)
ids = np.array(ids)
# don't convert ids to a numpy array: np.array() would turn tuple ids into arrays
# ids = np.array(ids)

groups = np.array(data["group"])
unique_groups = set(groups)

group_analysis_results = {}
for group_value in unique_groups:
group_ids = ids[groups == group_value]

# group_ids = ids[groups == group_value]
# use a list comprehension instead of boolean-mask indexing to support advanced id types such as tuples
group_ids = [ids[i] for i in range(len(ids)) if groups[i] == group_value]
group_analysis_results[str(group_value)] = self._metric.eval(
results, group_ids
)
Expand Down

0 comments on commit 788b374

Please sign in to comment.