Skip to content

Commit

Permalink
Function to handle loss functions with or without sample_weight arg
Browse files Browse the repository at this point in the history
  • Loading branch information
danielarifmurphy committed May 1, 2024
1 parent 5348a51 commit 5fc4b91
Showing 1 changed file with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,18 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari
loss_features['_baseline_'] = loss_baseline

return pd.DataFrame(loss_features, index=[0])


def calculate_loss(loss_function, observed, predicted, sample_weights=None):
# Determine if loss function accepts 'sample_weight'
loss_args = inspect.signature(loss_function).parameters
supports_weight = "sample_weight" in loss_args

if supports_weight:
return loss_function(observed, predicted, sample_weight=sample_weights)
else:
if sample_weights:
warnings.warn(
f"Loss function {loss_function.__name__} does not take sample weights. Calculating unweighted loss."
)
return loss_function(observed, predicted)

0 comments on commit 5fc4b91

Please sign in to comment.