From 6d8a9c12b26f139643ab5a00a176a39072e451cf Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 10 May 2022 14:16:13 +1200 Subject: [PATCH 1/7] add class_weight support to close #134 --- src/tuned_models.jl | 68 ++++++++++++++++++++++++++++++-------------- test/tuned_models.jl | 44 ++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 22 deletions(-) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index a9102cc..15249be 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -31,6 +31,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter resampling # resampling strategy measure weights::Union{Nothing,Vector{<:Real}} + class_weights::Union{Nothing,Dict} operation range selection_heuristic @@ -49,6 +50,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba resampling # resampling strategy measure weights::Union{Nothing,AbstractVector{<:Real}} + class_weights::Union{Nothing,Dict} operation range selection_heuristic @@ -114,6 +116,8 @@ Calling `fit!(mach)` on a machine `mach=machine(tuned_model, X, y)` or internal machine. The final train can be supressed by setting `train_best=false`. +### Search space + The `range` objects supported depend on the `tuning` strategy specified. Query the `strategy` docstring for details. To optimize over an explicit list `v` of models of the same type, use @@ -124,28 +128,26 @@ then `MLJTuning.default_n(tuning, range)` is used. When `n` is increased and `fit!(mach)` called again, the old search history is re-instated and the search continues where it left off. -If `measure` supports weights (`supports_weights(measure) == true`) -then any `weights` specified will be passed to the measure. 
If more -than one `measure` is specified, then only the first is optimized -(unless `strategy` is multi-objective) but the performance against -every measure specified will be computed and reported in -`report(mach).best_performance` and other relevant attributes of the -generated report. +### Measures (metrics) -Specify `repeats > 1` for repeated resampling per model -evaluation. See [`evaluate!`](@ref) options for details. +If more than one `measure` is specified, then only the first is +optimized (unless `strategy` is multi-objective) but the performance +against every measure specified will be computed and reported in +`report(mach).best_performance` and other relevant attributes of the +generated report. Options exist to pass per-observation weights or +class weights to measures; see below. -*Important.* If a custom `measure` is used, and the measure is -a score, rather than a loss, be sure to check that -`MLJ.orientation(measure) == :score` to ensure maximization of the +*Important.* If a custom measure, `my_measure` is used, and the +measure is a score, rather than a loss, be sure to check that +`MLJ.orientation(my_measure) == :score` to ensure maximization of the measure, rather than minimization. Override an incorrect value with -`MLJ.orientation(::typeof(measure)) = :score`. +`MLJ.orientation(::typeof(my_measure)) = :score`. + +### Accessing the fitted parameters and other training (tuning) outcomes A Plots.jl plot of performance estimates is returned by `plot(mach)` or `heatmap(mach)`. -### Accessing the fitted parameters and other training (tuning) outcomes - Once a tuning machine `mach` has bee trained as above, then `fitted_params(mach)` has these keys/values: @@ -165,7 +167,7 @@ key | value plus other key/value pairs specific to the `tuning` strategy. 
-### Summary of key-word arguments +### Complete list of key-word options - `model`: `Supervised` model prototype that is cloned and mutated to generate models for evaluation @@ -185,11 +187,15 @@ plus other key/value pairs specific to the `tuning` strategy. evaluations; only the first used in optimization (unless the strategy is multi-objective) but all reported to the history -- `weights`: sample weights to be passed the measure(s) in performance - evaluations, if supported. +- `weights`: per-observation weights to be passed to the measure(s) in performance + evaluations, where supported. Check support with `supports_weights(measure)`. + +- `class_weights`: class weights to be passed to the measure(s) in + performance evaluations, where supported. Check support with + `supports_class_weights(measure)`. - `repeats=1`: for generating train/test sets multiple times in - resampling; see [`evaluate!`](@ref) for details + resampling ("Monte Carlo" resampling); see [`evaluate!`](@ref) for details - `operation`/`operations` - One of $(MLJBase.PREDICT_OPERATIONS_STRING), or a vector of these of the @@ -233,6 +239,7 @@ function TunedModel(; model=nothing, measures=nothing, measure=measures, weights=nothing, + class_weights=nothing, operations=nothing, operation=operations, ranges=nothing, @@ -295,9 +302,24 @@ function TunedModel(; model=nothing, # get the tuning type parameter: T = typeof(tuning) - args = (model, tuning, resampling, measure, weights, operation, range, - selection_heuristic, train_best, repeats, n, acceleration, acceleration_resampling, - check_measure, cache) + args = ( + model, + tuning, + resampling, + measure, + weights, + class_weights, + operation, + range, + selection_heuristic, + train_best, + repeats, + n, + acceleration, + acceleration_resampling, + check_measure, + cache + ) if M <: DeterministicTypes tuned_model = DeterministicTunedModel{T,M}(args...)
@@ -531,6 +553,7 @@ function assemble_events!(metamodels, resampling = resampling_machine.model.resampling, measure = resampling_machine.model.measure, weights = resampling_machine.model.weights, + class_weights = resampling_machine.model.class_weights, operation = resampling_machine.model.operation, check_measure = resampling_machine.model.check_measure, repeats = resampling_machine.model.repeats, @@ -693,6 +716,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M}, resampling = deepcopy(tuned_model.resampling), measure = tuned_model.measure, weights = tuned_model.weights, + class_weights = tuned_model.class_weights, operation = tuned_model.operation, check_measure = tuned_model.check_measure, repeats = tuned_model.repeats, diff --git a/test/tuned_models.jl b/test/tuned_models.jl index 470ae4a..8adb74a 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -341,4 +341,48 @@ end end +@testset_accelerated "weights and class_weights are being passed" accel begin + # we'll be tuning using 50/50 holdout + X = (x=fill(1.0, 6),) + y = coerce(["a", "a", "b", "a", "a", "b"], OrderedFactor) + w = [1.0, 1.0, 100.0, 1.0, 1.0, 100.0] + class_w = Dict("a" => 2.0, "b" => 100.0) + + model = DecisionTreeClassifier() + + # the first supports weights, the second class weights: + ms=[MisclassificationRate(), MulticlassFScore()] + + resampling=Holdout(fraction_train=0.5) + + # without weights: + tmodel = TunedModel( + resampling=resampling, + models=fill(model, 5), + measures=ms, + acceleration=accel + ) + mach = machine(tmodel, X, y) + fit!(mach, verbosity=0) + measurement = report(mach).best_history_entry.measurement + e = evaluate(model, X, y, measures=ms, resampling=resampling) + @test measurement == e.measurement + + # with weights: + tmodel.weights = w + tmodel.class_weights = class_w + fit!(mach, verbosity=0) + measurement_weighted = report(mach).best_history_entry.measurement + e_weighted = evaluate(model, X, y; + measures=ms, + resampling=resampling, + 
weights=w, + class_weights=class_w, + verbosity=-1) + @test measurement_weighted == e_weighted.measurement + + # check both measures are different when they are weighted: + @test !any(measurement .== measurement_weighted) +end + true From 5fc10967e28193ad1c91d1bf85c1b47c5a468bf7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 10 May 2022 16:12:15 +1200 Subject: [PATCH 2/7] suppress data caching at outer level to close #171 --- src/tuned_models.jl | 3 ++- test/tuned_models.jl | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index a9102cc..02ff9b2 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -64,7 +64,8 @@ end const EitherTunedModel{T,M} = Union{DeterministicTunedModel{T,M},ProbabilisticTunedModel{T,M}} -#todo update: +MLJBase.caches_data_by_default(::Type{<:EitherTunedModel}) = false + """ tuned_model = TunedModel(; model=, tuning=RandomSearch(), diff --git a/test/tuned_models.jl b/test/tuned_models.jl index 470ae4a..ff53a4c 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -341,4 +341,15 @@ end end +@testset "data caching at outer level suppressed" begin + X, y = make_blobs() + model = DecisionTreeClassifier() + tmodel = TunedModel(models=[model,]) + mach = machine(tmodel, X, y) + @test mach isa Machine{<:Any,false} + fit!(mach, verbosity=-1) + @test !isdefined(mach, :data) + @test MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1]) +end + true From 45a676edf75ae4a906858409a4071847c2ca55b0 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Mon, 23 May 2022 09:35:52 +1200 Subject: [PATCH 3/7] Update src/tuned_models.jl Co-authored-by: Venkateshprasad <32921645+ven-k@users.noreply.github.com> --- src/tuned_models.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 15249be..735997d 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -30,8 +30,8 @@ mutable struct
DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter tuning::T # tuning strategy resampling # resampling strategy measure - weights::Union{Nothing,Vector{<:Real}} - class_weights::Union{Nothing,Dict} + weights::Union{Nothing,AbstractVector{<:Real}} + class_weights::Union{Nothing,AbstractDict} operation range selection_heuristic From 71bd5f7ba98ece6c5868140d50c9e441b78e6852 Mon Sep 17 00:00:00 2001 From: "Anthony Blaom, PhD" Date: Mon, 23 May 2022 09:36:30 +1200 Subject: [PATCH 4/7] Dict -> AbstractDict Co-authored-by: Venkateshprasad <32921645+ven-k@users.noreply.github.com> --- src/tuned_models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 735997d..074ce73 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -50,7 +50,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba resampling # resampling strategy measure weights::Union{Nothing,AbstractVector{<:Real}} - class_weights::Union{Nothing,Dict} + class_weights::Union{Nothing,AbstractDict} operation range selection_heuristic From d009cb79ea633276e588db2d395cb852c2c99841 Mon Sep 17 00:00:00 2001 From: "Anthony D. 
Blaom" Date: Mon, 23 May 2022 09:38:17 +1200 Subject: [PATCH 5/7] overload supports_class_weights for TunedModel --- src/tuned_models.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 074ce73..991c15d 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -808,6 +808,8 @@ end MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M = MLJBase.supports_weights(M) +MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M = + MLJBase.supports_class_weights(M) MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) = "MLJTuning.ProbabilisticTunedModel" MLJBase.load_path(::Type{<:DeterministicTunedModel}) = From 0174e9382f64f6905b504df8d1fa375a3a249fee Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 24 May 2022 16:45:12 +1200 Subject: [PATCH 6/7] Allow user to specify model in `TunedModel` as arg instead of kwarg --- src/tuned_models.jl | 17 +++++++++++++++-- test/tuned_models.jl | 13 +++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 1b4936b..853c74d 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -21,6 +21,10 @@ const ERR_MODEL_TYPE = ArgumentError( "Only `Deterministic` and `Probabilistic` model types supported.") const INFO_MODEL_IGNORED = "`model` being ignored. Using `model=first(range)`. " +const ERR_TOO_MANY_ARGUMENTS = + ArgumentError("At most one non-keyword argument allowed. ") +warn_double_spec(arg, model) = + "Using `model=$arg`. Ignoring keyword specification `model=$model`. " const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector} const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector} @@ -233,7 +237,7 @@ plus other key/value pairs specific to the `tuning` strategy. likely limited to the case `resampling isa Holdout`. 
""" -function TunedModel(; model=nothing, +function TunedModel(args...; model=nothing, models=nothing, tuning=nothing, resampling=MLJBase.Holdout(), @@ -254,8 +258,17 @@ function TunedModel(; model=nothing, check_measure=true, cache=true) + # user can specify model as argument instead of kwarg: + length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS) + if length(args) === 1 + arg = first(args) + model === nothing || + @warn warn_double_spec(arg, model) + model =arg + end + # either `models` is specified and `tuning` is set to `Explicit`, - # or `models` is unspecified and tuning will fallback to `Grid()` + # or `models` is unspecified and tuning will fallback to `RandomSearch()` # unless it is itself specified: if models !== nothing if tuning === nothing diff --git a/test/tuned_models.jl b/test/tuned_models.jl index 24563da..c512d9a 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -35,7 +35,7 @@ r = [m(K) for K in 13:-1:2] @test_throws(MLJTuning.ERR_BOTH_DISALLOWED, TunedModel(model=first(r), models=r, tuning=Explicit(), measure=rms)) - tm = TunedModel(models=r, tuning=Explicit(), measure=rms) + tm = @test_logs TunedModel(models=r, tuning=Explicit(), measure=rms) @test tm.tuning isa Explicit && tm.range ==r && tm.model == first(r) @test input_scitype(tm) == Unknown @test TunedModel(models=r, measure=rms) == tm @@ -54,7 +54,16 @@ r = [m(K) for K in 13:-1:2] TunedModel(tuning=Explicit(), measure=rms)) @test_throws(MLJTuning.ERR_NEED_EXPLICIT, TunedModel(models=r, tuning=Grid())) - tm = TunedModel(model=first(r), range=r, measure=rms) + @test_logs TunedModel(first(r), range=r, measure=rms) + @test_logs( + (:warn, MLJTuning.warn_double_spec(first(r), last(r))), + TunedModel(first(r), model=last(r), range=r, measure=rms), + ) + @test_throws( + MLJTuning.ERR_TOO_MANY_ARGUMENTS, + TunedModel(first(r), last(r), range=r, measure=rms), + ) + tm = @test_logs TunedModel(model=first(r), range=r, measure=rms) @test tm.tuning isa RandomSearch @test input_scitype(tm) == 
Table(Continuous) end From 72a9f8d7074f9dba55ce9c8bf34db605a5d81ca3 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 24 May 2022 16:57:13 +1200 Subject: [PATCH 7/7] bump 0.7.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 784451c..c51915d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJTuning" uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" authors = ["Anthony D. Blaom "] -version = "0.7.0" +version = "0.7.1" [deps] ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"