-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refreactored regression metrics to new dir
- Loading branch information
1 parent
8aed967
commit da7b348
Showing
4 changed files
with
134 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from typing import List, Optional, Union | ||
from fuse.eval.metrics.libs.stat import Stat | ||
from fuse.eval.metrics.metrics_common import MetricDefault | ||
import numpy as np | ||
from sklearn.metrics import mean_absolute_error, mean_squared_error | ||
|
||
|
||
class MetricPearsonCorrelation(MetricDefault): | ||
def __init__( | ||
self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict | ||
) -> None: | ||
super().__init__( | ||
pred=pred, | ||
target=target, | ||
mask=mask, | ||
metric_func=Stat.pearson_correlation, | ||
**kwargs, | ||
) | ||
|
||
|
||
class MetricSpearmanCorrelation(MetricDefault): | ||
def __init__( | ||
self, pred: str, target: str, mask: Optional[str] = None, **kwargs: dict | ||
) -> None: | ||
super().__init__( | ||
pred=pred, | ||
target=target, | ||
mask=mask, | ||
metric_func=Stat.spearman_correlation, | ||
**kwargs, | ||
) | ||
|
||
|
||
class MetricMAE(MetricDefault): | ||
def __init__( | ||
self, | ||
pred: str, | ||
target: str, | ||
**kwargs: dict, | ||
) -> None: | ||
""" | ||
See MetricDefault for the missing params | ||
:param pred: scalar predictions | ||
:param target: ground truth scalar labels | ||
:param threshold: threshold to apply to both pred and target | ||
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy. | ||
""" | ||
super().__init__( | ||
pred=pred, | ||
target=target, | ||
metric_func=self.mae, | ||
**kwargs, | ||
) | ||
|
||
def mae( | ||
self, | ||
pred: Union[List, np.ndarray], | ||
target: Union[List, np.ndarray], | ||
**kwargs: dict, | ||
) -> float: | ||
return mean_absolute_error(y_true=target, y_pred=pred) | ||
|
||
|
||
class MetricMSE(MetricDefault): | ||
def __init__( | ||
self, | ||
pred: str, | ||
target: str, | ||
**kwargs: dict, | ||
) -> None: | ||
""" | ||
Our implementation of standard MSE, current version of scikit dones't support it as a metric. | ||
See MetricDefault for the missing params | ||
:param pred: scalar predictions | ||
:param target: ground truth scalar labels | ||
:param threshold: threshold to apply to both pred and target | ||
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy. | ||
""" | ||
super().__init__( | ||
pred=pred, | ||
target=target, | ||
metric_func=self.mse, | ||
**kwargs, | ||
) | ||
|
||
def mse( | ||
self, | ||
pred: Union[List, np.ndarray], | ||
target: Union[List, np.ndarray], | ||
**kwargs: dict, | ||
) -> float: | ||
return mean_squared_error(y_true=target, y_pred=pred) | ||
|
||
|
||
class MetricRMSE(MetricDefault): | ||
def __init__( | ||
self, | ||
pred: str, | ||
target: str, | ||
**kwargs: dict, | ||
) -> None: | ||
""" | ||
See MetricDefault for the missing params | ||
:param pred: scalar predictions | ||
:param target: ground truth scalar labels | ||
:param threshold: threshold to apply to both pred and target | ||
:param balanced: optionally to use balanced accuracy (from sklearn) instead of regular accuracy. | ||
""" | ||
super().__init__( | ||
pred=pred, | ||
target=target, | ||
metric_func=self.mse, | ||
**kwargs, | ||
) | ||
|
||
def mse( | ||
self, | ||
pred: Union[List, np.ndarray], | ||
target: Union[List, np.ndarray], | ||
**kwargs: dict, | ||
) -> float: | ||
|
||
pred = np.array(pred).flatten() | ||
target = np.array(target).flatten() | ||
|
||
assert len(pred) == len( | ||
target | ||
), f"Expected pred and target to have the dimensions but found: {len(pred)} elements in pred and {len(target)} in target" | ||
|
||
squared_diff = (pred - target) ** 2 | ||
return squared_diff.mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters