Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace built-in measures with measures in StatisticalMeasures.jl #909

Merged
merged 14 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:
branches:
- master
- dev
- for-a-0-point-21-release
- for-a-0-point-22-release
- next-breaking-release
push:
branches:
Expand Down
19 changes: 16 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,48 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LearnAPI = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"

[extensions]
DefaultMeasuresExt = "StatisticalMeasures"

[compat]
CategoricalArrays = "0.9, 0.10"
CategoricalDistributions = "0.1"
ComputationalResources = "0.3"
Distributions = "0.25.3"
InvertedIndices = "1"
LossFunctions = "0.10"
LearnAPI = "0.1"
MLJModelInterface = "1.7"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
PrettyTables = "1, 2"
ProgressMeter = "1.7.1"
Reexport = "1.2"
ScientificTypes = "3"
StatisticalMeasures = "0.1.1"
StatisticalMeasuresBase = "0.1.1"
StatisticalTraits = "3.2"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1.0"
Expand All @@ -56,8 +68,9 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "Test", "TypedTables"]
test = ["DataFrames", "DecisionTree", "Distances", "Logging", "MultivariateStats", "NearestNeighbors", "StableRNGs", "StatisticalMeasures", "Test", "TypedTables"]
15 changes: 15 additions & 0 deletions ext/DefaultMeasuresExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module DefaultMeasuresExt

# Package extension supplying default performance measures for MLJBase model
# types, using measures from StatisticalMeasures.jl. On Julia >= 1.9 this is
# loaded automatically when StatisticalMeasures.jl is present (see [weakdeps]
# in Project.toml); on earlier Julia versions it is included directly.

using MLJBase
import MLJBase: default_measure, ProbabilisticDetector, DeterministicDetector
using StatisticalMeasures
using StatisticalMeasures.ScientificTypesBase

# Map (model type, target observation scitype) to a sensible default measure.
# The two-argument fallback in MLJBase returns `nothing`, so only known
# combinations are listed here:
default_measure(::Deterministic, ::Type{<:Union{Continuous,Count}}) = l2
default_measure(::Deterministic, ::Type{<:Finite}) = misclassification_rate
default_measure(::Probabilistic, ::Type{<:Union{Finite,Count}}) = log_loss
default_measure(::Probabilistic, ::Type{<:Continuous}) = log_loss
default_measure(::ProbabilisticDetector, ::Type{<:OrderedFactor{2}}) = area_under_curve
default_measure(::DeterministicDetector, ::Type{<:OrderedFactor{2}}) = balanced_accuracy

end # module
90 changes: 14 additions & 76 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module MLJBase
# ===================================================================
# IMPORTS

using Reexport
import Base: ==, precision, getindex, setindex!
import Base.+, Base.*, Base./

Expand All @@ -16,7 +17,7 @@ for trait in StatisticalTraits.TRAITS
eval(:(import StatisticalTraits.$trait))
end

import Base.instances # considered a trait for measures
import LearnAPI
import StatisticalTraits.snakecase
import StatisticalTraits.info

Expand Down Expand Up @@ -47,7 +48,7 @@ end
###################
# Hack Block ends #
###################

import MLJModelInterface: ProbabilisticDetector, DeterministicDetector
import MLJModelInterface: fit, update, update_data, transform,
inverse_transform, fitted_params, predict, predict_mode,
predict_mean, predict_median, predict_joint,
Expand Down Expand Up @@ -78,7 +79,6 @@ using ProgressMeter
import .Threads

# Operations & extensions
import LossFunctions
import StatsBase
import StatsBase: fit!, mode, countmap
import Missings: levels
Expand All @@ -88,6 +88,9 @@ using CategoricalDistributions
import Distributions: pdf, logpdf, sampler
const Dist = Distributions

# Measures
import StatisticalMeasuresBase

# from Standard Library:
using Statistics, LinearAlgebra, Random, InteractiveUtils

Expand Down Expand Up @@ -127,57 +130,6 @@ const CatArrMissing{T,N} = ArrMissing{CategoricalValue{T},N}
const MMI = MLJModelInterface
const FI = MLJModelInterface.FullInterface

const MARGIN_LOSSES = [
:DWDMarginLoss,
:ExpLoss,
:L1HingeLoss,
:L2HingeLoss,
:L2MarginLoss,
:LogitMarginLoss,
:ModifiedHuberLoss,
:PerceptronLoss,
:SigmoidLoss,
:SmoothedL1HingeLoss,
:ZeroOneLoss
]

const DISTANCE_LOSSES = [
:HuberLoss,
:L1EpsilonInsLoss,
:L2EpsilonInsLoss,
:LPDistLoss,
:LogitDistLoss,
:PeriodicLoss,
:QuantileLoss
]

const WITH_PARAMETERS = [
:DWDMarginLoss,
:SmoothedL1HingeLoss,
:HuberLoss,
:L1EpsilonInsLoss,
:L2EpsilonInsLoss,
:LPDistLoss,
:QuantileLoss,
]

const MEASURE_TYPE_ALIASES = [
:FPR, :FNR, :TPR, :TNR,
:FDR, :PPV, :NPV, :Recall, :Specificity,
:MFPR, :MFNR, :MTPR, :MTNR,
:MFDR, :MPPV, :MNPV, :MulticlassRecall, :MulticlassSpecificity,
:MCR,
:MCC,
:BAC, :BACC,
:RMS, :RMSPV, :RMSL, :RMSLP, :RMSP,
:MAV, :MAE, :MAPE,
:RSQ, :LogCosh,
:CrossEntropy,
:AUC
]

const LOSS_FUNCTIONS = vcat(MARGIN_LOSSES, DISTANCE_LOSSES)

# ===================================================================
# Computational Resource
# default_resource allows to switch the mode of parallelization
Expand Down Expand Up @@ -224,15 +176,10 @@ include("data/data.jl")
include("data/datasets.jl")
include("data/datasets_synthetic.jl")

include("measures/measures.jl")
include("measures/measure_search.jl")
include("measures/doc_strings.jl")
include("default_measures.jl")

include("composition/models/stacking.jl")

# function on the right-hand side is defined in src/measures/meta_utilities.jl:
const MEASURE_TYPES_ALIASES_AND_INSTANCES = measures_for_export()

const EXTENDED_ABSTRACT_MODEL_TYPES = vcat(
MLJBase.MLJModelInterface.ABSTRACT_MODEL_SUBTYPES,
MLJBase.NETWORK_COMPOSITE_TYPES, # src/composition/models/network_composite_types.jl
Expand Down Expand Up @@ -355,28 +302,19 @@ export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV,
# -------------------------------------------------------------------
# exports from MLJBase specific to measures

# measure names:
for m in MEASURE_TYPES_ALIASES_AND_INSTANCES
:(export $m) |> eval
end

# measures/registry.jl:
export measures, metadata_measure

# measure/measures.jl (excluding traits):
export aggregate, default_measure, value, skipinvalid

# measures/probabilistic:
export roc_curve, roc

# measures/finite.jl (averaging modes for multiclass scores)
export no_avg, macro_avg, micro_avg

export default_measure

# -------------------------------------------------------------------
# re-export from Random, StatsBase, Statistics, Distributions,
# OrderedCollections, CategoricalArrays, InvertedIndices:
export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
levels, levels!, std, Not, support, logpdf, LittleDict

# for julia < 1.9
if !isdefined(Base, :get_extension)
include(joinpath("..","ext", "DefaultMeasuresExt.jl"))
@reexport using .DefaultMeasuresExt.StatisticalMeasures
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So StatisticalMeasures is only available in the public (re-exported) scope in Julia 1.9 or does the new extensions system from Julia automatically bring things to the public scope? (I'm not familiar with the new system yet)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If using MLJBase without MLJ, then, in Julia 1.9 or higher, StatisticalMeasures must be explicitly imported to use measures that were previously part of MLJBase (the names are automatically re-exported by a pkg extension). For earlier Julia versions, StatisticalMeasures.jl is a hard dependency (through a hack that is standard practice). If using MLJ, then all previous measures are still available, because StatisticalMeasures.jl will be a hard dep of MLJ.

end

end # module
40 changes: 24 additions & 16 deletions src/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,21 +378,30 @@
function internal_stack_report(
stack::Stack{modelnames,},
verbosity::Int,
tt_pairs,
tt_pairs, # train_test_pairs
folds_evaluations...
) where modelnames

n_measures = length(stack.measures)
nfolds = length(tt_pairs)

# For each model we record the results mimicking the fields PerformanceEvaluation
test_fold_sizes = map(tt_pairs) do train_test_pair
test = last(train_test_pair)
length(test)
end

# weights to be used to aggregate per-fold measurements (averaging to 1):
fold_weights(mode) = nfolds .* test_fold_sizes ./ sum(test_fold_sizes)
fold_weights(::StatisticalMeasuresBase.Sum) = nothing

Check warning on line 395 in src/composition/models/stacking.jl

View check run for this annotation

Codecov / codecov/patch

src/composition/models/stacking.jl#L395

Added line #L395 was not covered by tests

# For each model we record the results mimicking the fields of PerformanceEvaluation
results = NamedTuple{modelnames}(
[(
measure = stack.measures,
measurement = Vector{Any}(undef, n_measures),
operation = _actual_operations(nothing, stack.measures, model, verbosity),
per_fold = [Vector{Any}(undef, nfolds) for _ in 1:n_measures],
per_observation = Vector{Union{Missing, Vector{Any}}}(missing, n_measures),
per_observation = [Vector{Vector{Any}}(undef, nfolds) for _ in 1:n_measures],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check, per_observation is now a Vector{Vector{Vector{Any}}}?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, per_observation[i][j][k] is the measurement for the kth observation, in the jth fold, for the ith measure. There can no longer be missings because measures that don't compute per-observation measurements simply report copies of the aggregated measurement instead.

fitted_params_per_fold = [],
report_per_fold = [],
train_test_pairs = tt_pairs
Expand All @@ -416,30 +425,29 @@
model_results.operation,
))
ypred = operation(mach, Xtest)
loss = measure(ypred, ytest)
# Update per_observation
if reports_each_observation(measure)
if model_results.per_observation[i] === missing
model_results.per_observation[i] = Vector{Any}(undef, nfolds)
end
model_results.per_observation[i][foldid] = loss
end
measurements = StatisticalMeasuresBase.measurements(measure, ypred, ytest)

# Update per observation:
model_results.per_observation[i][foldid] = measurements

# Update per_fold
model_results.per_fold[i][foldid] =
reports_each_observation(measure) ?
MLJBase.aggregate(loss, measure) : loss
model_results.per_fold[i][foldid] = measure(ypred, ytest)
end
index += 1
end
end

# Update measurement field by aggregation
# Update measurement field by aggregating per-fold measurements
for modelname in modelnames
for (i, measure) in enumerate(stack.measures)
model_results = results[modelname]
mode = StatisticalMeasuresBase.external_aggregation_mode(measure)
model_results.measurement[i] =
MLJBase.aggregate(model_results.per_fold[i], measure)
StatisticalMeasuresBase.aggregate(
model_results.per_fold[i];
mode,
weights=fold_weights(mode),
)
end
end

Expand Down
23 changes: 23 additions & 0 deletions src/default_measures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# # DEFAULT MEASURES

"""
    default_measure(model)

Return a measure that should work with `model`, or return `nothing` if none can be
reliably inferred.

For Julia 1.9 and higher, `nothing` is returned, unless StatisticalMeasures.jl is
loaded.

# New implementations

This method dispatches `default_measure(model, observation_scitype)`, which has
`nothing` as the fallback return value. Extend `default_measure` by overloading this
version of the method. See for example the MLJBase.jl package extension,
DefaultMeasuresExt.jl.

"""
# catch-all fallback for models of unknown type:
default_measure(m) = nothing
# for supervised models and annotators, dispatch on the (non-missing) target
# observation scitype; scitype-specific methods live in the package extension:
default_measure(m::Union{Supervised,Annotator}) =
default_measure(m, nonmissingtype(guess_model_target_observation_scitype(m)))
# fallback when no measure is known for observation scitype `S`:
default_measure(m, S) = nothing
Loading