-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace built-in measures with measures in StatisticalMeasures.jl #909
Changes from 12 commits
2a7941f
1a4e5e1
4c54571
b1bfab1
da710ad
2692ca9
1dde0d9
a3fca52
91a6f54
e2d15dd
7ae73a5
588e777
e5dfa91
0549150
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
module DefaultMeasuresExt | ||
|
||
using MLJBase | ||
import MLJBase:default_measure, ProbabilisticDetector, DeterministicDetector | ||
using StatisticalMeasures | ||
using StatisticalMeasures.ScientificTypesBase | ||
|
||
default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}) = l2 | ||
default_measure(::Deterministic, ::Type{<:Finite}) = misclassification_rate | ||
default_measure(::Probabilistic, ::Type{<:Union{Finite,Count}}) = log_loss | ||
default_measure(::Probabilistic, ::Type{<:Continuous}) = log_loss | ||
default_measure(::ProbabilisticDetector, ::Type{<:OrderedFactor{2}}) = area_under_curve | ||
default_measure(::DeterministicDetector, ::Type{<:OrderedFactor{2}}) = balanced_accuracy | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -378,21 +378,30 @@ | |
function internal_stack_report( | ||
stack::Stack{modelnames,}, | ||
verbosity::Int, | ||
tt_pairs, | ||
tt_pairs, # train_test_pairs | ||
folds_evaluations... | ||
) where modelnames | ||
|
||
n_measures = length(stack.measures) | ||
nfolds = length(tt_pairs) | ||
|
||
# For each model we record the results mimicking the fields PerformanceEvaluation | ||
test_fold_sizes = map(tt_pairs) do train_test_pair | ||
test = last(train_test_pair) | ||
length(test) | ||
end | ||
|
||
# weights to be used to aggregate per-fold measurements (averaging to 1): | ||
fold_weights(mode) = nfolds .* test_fold_sizes ./ sum(test_fold_sizes) | ||
fold_weights(::StatisticalMeasuresBase.Sum) = nothing | ||
|
||
# For each model we record the results mimicking the fields of PerformanceEvaluation | ||
results = NamedTuple{modelnames}( | ||
[( | ||
measure = stack.measures, | ||
measurement = Vector{Any}(undef, n_measures), | ||
operation = _actual_operations(nothing, stack.measures, model, verbosity), | ||
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures], | ||
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures), | ||
per_observation = [Vector{Vector{Any}}(undef, nfolds) for _ in 1:n_measures], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to double check, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, |
||
fitted_params_per_fold = [], | ||
report_per_fold = [], | ||
train_test_pairs = tt_pairs | ||
|
@@ -416,30 +425,29 @@ | |
model_results.operation, | ||
)) | ||
ypred = operation(mach, Xtest) | ||
loss = measure(ypred, ytest) | ||
# Update per_observation | ||
if reports_each_observation(measure) | ||
if model_results.per_observation[i] === missing | ||
model_results.per_observation[i] = Vector{Any}(undef, nfolds) | ||
end | ||
model_results.per_observation[i][foldid] = loss | ||
end | ||
measurements = StatisticalMeasuresBase.measurements(measure, ypred, ytest) | ||
|
||
# Update per observation: | ||
model_results.per_observation[i][foldid] = measurements | ||
|
||
# Update per_fold | ||
model_results.per_fold[i][foldid] = | ||
reports_each_observation(measure) ? | ||
MLJBase.aggregate(loss, measure) : loss | ||
model_results.per_fold[i][foldid] = measure(ypred, ytest) | ||
end | ||
index += 1 | ||
end | ||
end | ||
|
||
# Update measurement field by aggregation | ||
# Update measurement field by aggregating per-fold measurements | ||
for modelname in modelnames | ||
for (i, measure) in enumerate(stack.measures) | ||
model_results = results[modelname] | ||
mode = StatisticalMeasuresBase.external_aggregation_mode(measure) | ||
model_results.measurement[i] = | ||
MLJBase.aggregate(model_results.per_fold[i], measure) | ||
StatisticalMeasuresBase.aggregate( | ||
model_results.per_fold[i]; | ||
mode, | ||
weights=fold_weights(mode), | ||
) | ||
end | ||
end | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# # DEFAULT MEASURES | ||
|
||
""" | ||
default_measure(model) | ||
|
||
Return a measure that should work with `model`, or return `nothing` if none can be | ||
reliably inferred. | ||
|
||
For Julia 1.9 and higher, `nothing` is returned, unless StatisticalMeasures.jl is | ||
loaded. | ||
|
||
# New implementations | ||
|
||
This method dispatches `default_measure(model, observation_scitype)`, which has | ||
`nothing` as the fallback return value. Extend `default_measure` by overloading this | ||
version of the method. See for example the MLJBase.jl package extension, | ||
DefaultMeausuresExt.jl. | ||
|
||
""" | ||
default_measure(m) = nothing | ||
default_measure(m::Union{Supervised,Annotator}) = | ||
default_measure(m, nonmissingtype(guess_model_target_observation_scitype(m))) | ||
default_measure(m, S) = nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So
StatisticalMeasures
is only available in the public (re-exported) scope in Julia 1.9 or does the new extensions system from Julia automatically bring things to the public scope? (I'm not familiar with the new system yet)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If
using MLJBase
without MLJ, then, in Julia 1.9 or higher,StatisticalMeasures
must be explicitly imported to use measures that were previously part of MLJBase (the names are automatically re-exported by a pkg extension). For earlier of Julia versions, StatisticalMeasures.jl is a hard dependency (through a hack that is standard practice). Ifusing MLJ
, then all previous measures are still available, because StatisticalMeasures.jl will be a hard dep of MLJ.