Matching key lengths and new metrics (#358)
* initial commit

* added post-batch collate ops to collate function

* added Matthews correlation coefficient and balanced accuracy metrics

* fixed type annotation errors and formatting

* formatting

* added documentation for metrics

---------

Co-authored-by: Ido Amos <[email protected]>
7 people authored Jun 30, 2024
1 parent 8905d68 commit d690708
Showing 2 changed files with 110 additions and 0 deletions.
27 changes: 27 additions & 0 deletions fuse/data/utils/collates.py
@@ -43,6 +43,7 @@ def __init__(
keep_keys: Sequence[str] = tuple(),
raise_error_key_missing: bool = True,
special_handlers_keys: Optional[Dict[str, Callable]] = None,
post_collate_special_handlers_keys: Optional[List[Callable]] = None,
add_to_batch_dict: Optional[Dict[str, Any]] = None,
):
"""
@@ -51,6 +52,7 @@ def __init__
:param special_handlers_keys: per key, specify a callable which gets a list of values as input and converts it to a batch.
The rest of the keys will be converted to a batch using PyTorch's default collate_fn().
An example of such a callable can be seen in CollateDefault.pad_all_tensors_to_same_size.
:param post_collate_special_handlers_keys: optional list of callables, each of which gets the batch_dict as input and applies post-processing to the batched tensors in place.
:param raise_error_key_missing: if False, will not raise an error when keys are missing from some of the samples; instead, those values will be set to None.
:param add_to_batch_dict: optional, fixed items to add to batch_dict
"""
@@ -64,6 +66,8 @@ def __init__
self._keep_keys = keep_keys
self._add_to_batch_dict = add_to_batch_dict

self._post_collate_special_handlers_keys = post_collate_special_handlers_keys

def __call__(self, samples: List[Dict]) -> Dict:
"""
collate list of samples into batch_dict
@@ -107,6 +111,10 @@ def __call__(self, samples: List[Dict]) -> Dict:
if self._add_to_batch_dict is not None:
batch_dict.update(self._add_to_batch_dict)

if self._post_collate_special_handlers_keys is not None:
for callable in self._post_collate_special_handlers_keys:
callable(batch_dict)

return batch_dict

def _batch_dispatch(
@@ -247,3 +255,22 @@ def crop_padding(
cropped_sequences = [ids[:min_length] for ids in input_ids_list]
batched_sequences = torch.stack(cropped_sequences, dim=0)
return batched_sequences

@staticmethod
def crop_length_to_match_target_key(
batch_dict: dict, target_key: str, keys_to_match: List[str]
) -> None:
"""
Crop tensors in batch_dict, specified by keys_to_match, along dimension 1 (typically the sequence length) to match a target tensor.
:param batch_dict: input dictionary with batched tensors of expected shape (B, L, *), where B is the batch dimension and L is a sequential dimension. L may vary across the elements in batch_dict. * is any number of additional dimensions.
:param target_key: key in batch_dict of the tensor with the desired sequential dimension, L'.
:param keys_to_match: list of keys in batch_dict whose tensors' dimension 1 should be cropped to match the target_key tensor. The tensors to match should have dimension 1 of size L >= L'.
:return: None, elements are modified in place in batch_dict.
"""
target_length = batch_dict[target_key].shape[1]
for key in keys_to_match:
batch_dict[key] = batch_dict[key][:, :target_length]
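
A minimal usage sketch of the two additions above. The batch keys ("data.tokens", "data.attention_mask") and tensor shapes are hypothetical, and binding the static method with functools.partial is one possible way to register it as a post-collate handler, not necessarily the only intended wiring:

import functools

import torch

from fuse.data.utils.collates import CollateDefault

# Hypothetical batch: "data.attention_mask" was padded to a longer length than "data.tokens".
batch_dict = {
    "data.tokens": torch.zeros(4, 100, dtype=torch.long),         # shape (B, L') = (4, 100)
    "data.attention_mask": torch.ones(4, 128, dtype=torch.long),  # shape (B, L) = (4, 128), L >= L'
}

# Crop dimension 1 of "data.attention_mask" to match "data.tokens".
CollateDefault.crop_length_to_match_target_key(
    batch_dict, target_key="data.tokens", keys_to_match=["data.attention_mask"]
)
assert batch_dict["data.attention_mask"].shape[1] == 100

# To apply the same cropping automatically after every collate, the handler can be
# bound with functools.partial and passed through the new constructor argument:
collate_fn = CollateDefault(
    post_collate_special_handlers_keys=[
        functools.partial(
            CollateDefault.crop_length_to_match_target_key,
            target_key="data.tokens",
            keys_to_match=["data.attention_mask"],
        )
    ]
)

The partial object receives only batch_dict when invoked, which matches the callable(batch_dict) loop added to __call__ above.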
83 changes: 83 additions & 0 deletions fuse/eval/metrics/classification/metrics_classification_common.py
@@ -25,6 +25,7 @@

from fuse.eval.metrics.metrics_common import MetricDefault, MetricWithCollectorBase
from fuse.eval.metrics.libs.classification import MetricsLibClass
from sklearn.metrics import matthews_corrcoef, balanced_accuracy_score


class MetricMultiClassDefault(MetricWithCollectorBase):
@@ -336,3 +337,85 @@ def __init__(self, pred: str, target: str, **kwargs: dict):
metric_func=MetricsLibClass.multi_class_bss,
**kwargs,
)


class MetricMCC(MetricDefault):
"""
Compute the Matthews correlation coefficient (MCC) for predictions
"""

def __init__(
self,
pred: Optional[str] = None,
target: Optional[str] = None,
sample_weight: Optional[str] = None,
**kwargs: dict,
):
"""
See MetricDefault for the missing params
:param pred: class label predictions
:param target: ground truth labels
:param sample_weight: weight per sample for the final score. Keep None if not required.
"""
super().__init__(
pred=pred,
target=target,
sample_weight=sample_weight,
metric_func=self.mcc_wrapper,
**kwargs,
)

def mcc_wrapper(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
sample_weight: Optional[Union[List, np.ndarray]] = None,
**kwargs: dict,
) -> float:
"""
Adapts the input format expected by MetricDefault to the sklearn matthews_corrcoef signature.
"""
res_dict = {"y_true": target, "y_pred": pred, "sample_weight": sample_weight}
score = matthews_corrcoef(**res_dict)
return score


class MetricBalAccuracy(MetricDefault):
"""
Compute the balanced accuracy score for predictions
"""

def __init__(
self,
pred: Optional[str] = None,
target: Optional[str] = None,
sample_weight: Optional[str] = None,
**kwargs: dict,
):
"""
See MetricDefault for the missing params
:param pred: class label predictions
:param target: ground truth labels
:param sample_weight: weight per sample for the final accuracy score. Keep None if not required.
"""
super().__init__(
pred=pred,
target=target,
sample_weight=sample_weight,
metric_func=self.balanced_acc_wrapper,
**kwargs,
)

def balanced_acc_wrapper(
self,
pred: Union[List, np.ndarray],
target: Union[List, np.ndarray],
sample_weight: Optional[Union[List, np.ndarray]] = None,
**kwargs: dict,
) -> float:
"""
Adapts the input format expected by MetricDefault to the sklearn balanced_accuracy_score signature.
"""
res_dict = {"y_true": target, "y_pred": pred, "sample_weight": sample_weight}
score = balanced_accuracy_score(**res_dict)
return score
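
Both wrappers above delegate directly to sklearn. A small sketch of what they compute, using hypothetical labels for a 3-class problem; the key names in the commented registration are likewise placeholders standing in for whatever keys hold the predictions and labels in a given evaluation:

import numpy as np
from sklearn.metrics import balanced_accuracy_score, matthews_corrcoef

# Hypothetical class-label predictions and ground truth for a 3-class problem.
y_pred = np.array([0, 2, 1, 1, 0, 2])
y_true = np.array([0, 1, 1, 1, 0, 2])

print(matthews_corrcoef(y_true=y_true, y_pred=y_pred))        # approx. 0.78
print(balanced_accuracy_score(y_true=y_true, y_pred=y_pred))  # approx. 0.89

# Registering the new metrics would look roughly like this:
# metrics = {
#     "mcc": MetricMCC(pred="model.preds", target="data.label"),
#     "balanced_accuracy": MetricBalAccuracy(pred="model.preds", target="data.label"),
# }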
