
Replace built-in measures with measures in StatisticalMeasures.jl #909

Merged: 14 commits, Sep 21, 2023
19 changes: 16 additions & 3 deletions Project.toml
@@ -13,37 +13,49 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"

[extensions]
DefaultMeasuresExt = "StatisticalMeasures"

[compat]
CategoricalArrays = "0.9, 0.10"
CategoricalDistributions = "0.1"
ComputationalResources = "0.3"
DelimitedFiles = "1"
Distributions = "0.25.3"
InvertedIndices = "1"
LossFunctions = "0.11"
LearnAPI = "0.1"
MLJModelInterface = "1.7"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
PrettyTables = "1, 2"
ProgressMeter = "1.7.1"
Reexport = "1.2"
ScientificTypes = "3"
StatisticalMeasures = "0.1.1"
StatisticalMeasuresBase = "0.1.1"
StatisticalTraits = "3.2"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1.0"
@@ -57,8 +69,9 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "StatisticalMeasures", "Test", "TypedTables"]
11 changes: 5 additions & 6 deletions README.md
@@ -36,17 +36,16 @@ repository provides core functionality for MLJ, including:

- basic utilities for **manipulating datasets** and for **synthesizing datasets** (src/data)

- a [small interface](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/#Custom-resampling-strategies-1) for **resampling strategies** and implementations, including `CV()`, `StratifiedCV` and `Holdout` (src/resampling.jl)
- a [small
interface](https://alan-turing-institute.github.io/MLJ.jl/dev/evaluating_model_performance/#Custom-resampling-strategies-1)
for **resampling strategies** and implementations, including `CV()`, `StratifiedCV` and
`Holdout` (src/resampling.jl). Actual performance evaluation measures (aka metrics), which previously
were provided by MLJBase.jl, now live in [StatisticalMeasures.jl](https://juliaai.github.io/StatisticalMeasures.jl/dev/).

- methods for **performance evaluation**, based on those resampling strategies (src/resampling.jl)

- **one-dimensional hyperparameter range types**, constructors and
associated methods, for use with
[MLJTuning](https://github.com/JuliaAI/MLJTuning.jl) (src/hyperparam)

- a [small
interface](https://alan-turing-institute.github.io/MLJ.jl/dev/performance_measures/#Traits-and-custom-measures-1)
for **performance measures** (losses and scores), implementation of about 60 such measures, including integration of the
[LossFunctions.jl](https://github.com/JuliaML/LossFunctions.jl)
library (src/measures). To be migrated into separate package in the near future.

15 changes: 15 additions & 0 deletions ext/DefaultMeasuresExt.jl
@@ -0,0 +1,15 @@
module DefaultMeasuresExt

using MLJBase
import MLJBase: default_measure, ProbabilisticDetector, DeterministicDetector
using StatisticalMeasures
using StatisticalMeasures.ScientificTypesBase

default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}) = l2
default_measure(::Deterministic, ::Type{<:Finite}) = misclassification_rate
default_measure(::Probabilistic, ::Type{<:Union{Finite,Count}}) = log_loss
default_measure(::Probabilistic, ::Type{<:Continuous}) = log_loss
default_measure(::ProbabilisticDetector, ::Type{<:OrderedFactor{2}}) = area_under_curve
default_measure(::DeterministicDetector, ::Type{<:OrderedFactor{2}}) = balanced_accuracy

end # module
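
For orientation, a minimal sketch (not part of the diff) of how these extension methods resolve once both packages are loaded on Julia 1.9 or later; `MyRegressor` and its trait declaration are hypothetical stand-ins for a real `Deterministic` model:

using MLJBase, StatisticalMeasures   # loading both activates DefaultMeasuresExt
import MLJModelInterface as MMI

mutable struct MyRegressor <: MMI.Deterministic end   # hypothetical model type
MMI.target_scitype(::Type{MyRegressor}) = AbstractVector{Continuous}

# MLJBase infers the target observation scitype (Continuous) and dispatches to the
# extension method default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}):
default_measure(MyRegressor())   # l2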
93 changes: 15 additions & 78 deletions src/MLJBase.jl
@@ -1,8 +1,9 @@
module MLJBase

# ===================================================================
# IMPORTS

using Reexport
import Base: ==, precision, getindex, setindex!
import Base.+, Base.*, Base./

@@ -16,7 +17,7 @@ for trait in StatisticalTraits.TRAITS
eval(:(import StatisticalTraits.$trait))
end

import Base.instances # considered a trait for measures
import LearnAPI
import StatisticalTraits.snakecase
import StatisticalTraits.info

@@ -47,7 +48,7 @@ end
###################
# Hack Block ends #
###################

import MLJModelInterface: ProbabilisticDetector, DeterministicDetector
import MLJModelInterface: fit, update, update_data, transform,
inverse_transform, fitted_params, predict, predict_mode,
predict_mean, predict_median, predict_joint,
@@ -78,8 +79,6 @@ using ProgressMeter
import .Threads

# Operations & extensions
import LossFunctions
import LossFunctions.Traits
import StatsBase
import StatsBase: fit!, mode, countmap
import Missings: levels
@@ -89,6 +88,9 @@ using CategoricalDistributions
import Distributions: pdf, logpdf, sampler
const Dist = Distributions

# Measures
import StatisticalMeasuresBase

# from Standard Library:
using Statistics, LinearAlgebra, Random, InteractiveUtils

@@ -128,57 +130,6 @@ const CatArrMissing{T,N} = ArrMissing{CategoricalValue{T},N}
const MMI = MLJModelInterface
const FI = MLJModelInterface.FullInterface

const MARGIN_LOSSES = [
:DWDMarginLoss,
:ExpLoss,
:L1HingeLoss,
:L2HingeLoss,
:L2MarginLoss,
:LogitMarginLoss,
:ModifiedHuberLoss,
:PerceptronLoss,
:SigmoidLoss,
:SmoothedL1HingeLoss,
:ZeroOneLoss
]

const DISTANCE_LOSSES = [
:HuberLoss,
:L1EpsilonInsLoss,
:L2EpsilonInsLoss,
:LPDistLoss,
:LogitDistLoss,
:PeriodicLoss,
:QuantileLoss
]

const WITH_PARAMETERS = [
:DWDMarginLoss,
:SmoothedL1HingeLoss,
:HuberLoss,
:L1EpsilonInsLoss,
:L2EpsilonInsLoss,
:LPDistLoss,
:QuantileLoss,
]

const MEASURE_TYPE_ALIASES = [
:FPR, :FNR, :TPR, :TNR,
:FDR, :PPV, :NPV, :Recall, :Specificity,
:MFPR, :MFNR, :MTPR, :MTNR,
:MFDR, :MPPV, :MNPV, :MulticlassRecall, :MulticlassSpecificity,
:MCR,
:MCC,
:BAC, :BACC,
:RMS, :RMSPV, :RMSL, :RMSLP, :RMSP,
:MAV, :MAE, :MAPE,
:RSQ, :LogCosh,
:CrossEntropy,
:AUC
]

const LOSS_FUNCTIONS = vcat(MARGIN_LOSSES, DISTANCE_LOSSES)

# ===================================================================
# Computational Resource
# default_resource allows to switch the mode of parallelization
@@ -225,15 +176,10 @@ include("data/data.jl")
include("data/datasets.jl")
include("data/datasets_synthetic.jl")

include("measures/measures.jl")
include("measures/measure_search.jl")
include("measures/doc_strings.jl")
include("default_measures.jl")

include("composition/models/stacking.jl")

# function on the right-hand side is defined in src/measures/meta_utilities.jl:
const MEASURE_TYPES_ALIASES_AND_INSTANCES = measures_for_export()

const EXTENDED_ABSTRACT_MODEL_TYPES = vcat(
MLJBase.MLJModelInterface.ABSTRACT_MODEL_SUBTYPES,
MLJBase.NETWORK_COMPOSITE_TYPES, # src/composition/models/network_composite_types.jl
@@ -357,28 +303,19 @@ export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV,
# -------------------------------------------------------------------
# exports from MLJBase specific to measures

# measure names:
for m in MEASURE_TYPES_ALIASES_AND_INSTANCES
:(export $m) |> eval
end

# measures/registry.jl:
export measures, metadata_measure

# measure/measures.jl (excluding traits):
export aggregate, default_measure, value, skipinvalid

# measures/probabilistic:
export roc_curve, roc

# measures/finite.jl (averaging modes for multiclass scores)
export no_avg, macro_avg, micro_avg

export default_measure

# -------------------------------------------------------------------
# re-export from Random, StatsBase, Statistics, Distributions,
# OrderedCollections, CategoricalArrays, InvertedIndices:
export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
levels, levels!, std, Not, support, logpdf, LittleDict

# for julia < 1.9
if !isdefined(Base, :get_extension)
include(joinpath("..","ext", "DefaultMeasuresExt.jl"))
@reexport using .DefaultMeasuresExt.StatisticalMeasures
end
Member commented:

So is StatisticalMeasures only available in the public (re-exported) scope on Julia 1.9+, or does the new extensions system automatically bring things into the public scope? (I'm not familiar with the new system yet.)

Member Author replied:

If using MLJBase without MLJ then, on Julia 1.9 or higher, StatisticalMeasures must be explicitly imported to use measures that were previously part of MLJBase (its names are then automatically re-exported by the package extension). For earlier Julia versions, StatisticalMeasures.jl is a hard dependency (through a hack that is standard practice). If using MLJ, all previous measures remain available, because StatisticalMeasures.jl will be a hard dependency of MLJ.


end # module
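
To illustrate the reply in the review thread above, a rough sketch of the intended usage on Julia 1.9+ when MLJBase is used without MLJ (`rms` is just one example of a relocated measure):

using MLJBase              # measure names such as rms and log_loss are not yet in scope
using StatisticalMeasures  # brings in the measures and activates DefaultMeasuresExt

rms([1.0, 2.0], [1.5, 2.5])   # 0.5

# On Julia < 1.9, the conditional include above makes StatisticalMeasures a hard
# dependency, so `using MLJBase` alone already re-exports the measure names.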
40 changes: 24 additions & 16 deletions src/composition/models/stacking.jl
@@ -378,22 +378,31 @@
function internal_stack_report(
stack::Stack{modelnames,},
verbosity::Int,
tt_pairs,
tt_pairs, # train_test_pairs
folds_evaluations...
) where modelnames

n_measures = length(stack.measures)
nfolds = length(tt_pairs)

# For each model we record the results mimicking the fields PerformanceEvaluation
test_fold_sizes = map(tt_pairs) do train_test_pair
test = last(train_test_pair)
length(test)
end

# weights to be used to aggregate per-fold measurements (averaging to 1):
fold_weights(mode) = nfolds .* test_fold_sizes ./ sum(test_fold_sizes)
fold_weights(::StatisticalMeasuresBase.Sum) = nothing

# For each model we record the results mimicking the fields of PerformanceEvaluation
results = NamedTuple{modelnames}(
[(
model = model,
measure = stack.measures,
measurement = Vector{Any}(undef, n_measures),
operation = _actual_operations(nothing, stack.measures, model, verbosity),
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
per_observation = [Vector{Vector{Any}}(undef, nfolds) for _ in 1:n_measures],
Member commented:

Just to double check: per_observation is now a Vector{Vector{Vector{Any}}}?

Member Author replied:

Yes — per_observation[i][j][k] is the measurement for the kth observation, in the jth fold, for the ith measure. There can no longer be any missings, because measures that don't compute per-observation measurements simply report copies of the aggregated measurement instead.

fitted_params_per_fold = [],
report_per_fold = [],
train_test_pairs = tt_pairs,
@@ -419,30 +428,29 @@
model_results.operation,
))
ypred = operation(mach, Xtest)
loss = measure(ypred, ytest)
# Update per_observation
if reports_each_observation(measure)
if model_results.per_observation[i] === missing
model_results.per_observation[i] = Vector{Any}(undef, nfolds)
end
model_results.per_observation[i][foldid] = loss
end
measurements = StatisticalMeasuresBase.measurements(measure, ypred, ytest)

# Update per observation:
model_results.per_observation[i][foldid] = measurements

# Update per_fold
model_results.per_fold[i][foldid] =
reports_each_observation(measure) ?
MLJBase.aggregate(loss, measure) : loss
model_results.per_fold[i][foldid] = measure(ypred, ytest)
end
index += 1
end
end

# Update measurement field by aggregation
# Update measurement field by aggregating per-fold measurements
for modelname in modelnames
for (i, measure) in enumerate(stack.measures)
model_results = results[modelname]
mode = StatisticalMeasuresBase.external_aggregation_mode(measure)
model_results.measurement[i] =
MLJBase.aggregate(model_results.per_fold[i], measure)
StatisticalMeasuresBase.aggregate(
model_results.per_fold[i];
mode,
weights=fold_weights(mode),
)
end
end

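A small sketch, with made-up numbers, of the two StatisticalMeasuresBase calls used above: `measurements`, which feeds `per_observation`, and the keyword form of `aggregate`, which produces the final `measurement` entries:

using StatisticalMeasures, StatisticalMeasuresBase

yhat = [1.0, 2.0, 4.0]; y = [1.0, 3.0, 2.0]

StatisticalMeasuresBase.measurements(l2, yhat, y)   # [0.0, 1.0, 4.0] -> per_observation[i][foldid]
l2(yhat, y)                                         # ≈ 1.67 (their mean) -> per_fold[i][foldid]

# Aggregate per-fold values using the measure's own aggregation mode and fold weights;
# with three folds of test sizes 3, 4, 3, fold_weights gives 3 .* [3, 4, 3] ./ 10,
# which averages to 1:
mode = StatisticalMeasuresBase.external_aggregation_mode(l2)   # Mean()
StatisticalMeasuresBase.aggregate([1.67, 2.0, 1.5]; mode, weights=[0.9, 1.2, 0.9])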
23 changes: 23 additions & 0 deletions src/default_measures.jl
@@ -0,0 +1,23 @@
# # DEFAULT MEASURES

"""
default_measure(model)

Return a measure that should work with `model`, or return `nothing` if none can be
reliably inferred.

On Julia 1.9 and higher, `nothing` is returned unless StatisticalMeasures.jl is also
loaded (which activates the DefaultMeasuresExt package extension).

# New implementations

This method dispatches to `default_measure(model, observation_scitype)`, which has
`nothing` as the fallback return value. Extend `default_measure` by overloading this
two-argument method. See, for example, the MLJBase.jl package extension
DefaultMeasuresExt.jl.

"""
default_measure(m) = nothing
default_measure(m::Union{Supervised,Annotator}) =
default_measure(m, nonmissingtype(guess_model_target_observation_scitype(m)))
default_measure(m, S) = nothing
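
A brief sketch of the extension point described in the docstring: a package providing a hypothetical model type can declare its own default by overloading the two-argument method (the model type and the choice of measure below are illustrative only):

using MLJBase, StatisticalMeasures
import MLJModelInterface as MMI

struct MyOrdinalClassifier <: MMI.Deterministic end   # hypothetical model type

MLJBase.default_measure(::MyOrdinalClassifier, ::Type{<:Finite}) = balanced_accuracy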