From 5fc4b91f27de368b6329e211dc1533b5c095225b Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:50:41 +0000 Subject: [PATCH] Function to handle loss functions with or without sample_weight arg --- .../_variable_importance/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 0f8de5ad..47a72d7c 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -82,3 +82,18 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari loss_features['_baseline_'] = loss_baseline return pd.DataFrame(loss_features, index=[0]) + + +def calculate_loss(loss_function, observed, predicted, sample_weights=None): + # Determine if loss function accepts 'sample_weight' + loss_args = inspect.signature(loss_function).parameters + supports_weight = "sample_weight" in loss_args + + if supports_weight: + return loss_function(observed, predicted, sample_weight=sample_weights) + else: + if sample_weights: + warnings.warn( + f"Loss function {loss_function.__name__} does not take sample weights. Calculating unweighted loss." + ) + return loss_function(observed, predicted)