diff --git a/Project.toml b/Project.toml
index 289790b..35cb0c0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MLJTuning"
 uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
 authors = ["Anthony D. Blaom "]
-version = "0.8.4"
+version = "0.8.5"
 
 [deps]
 ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
@@ -18,7 +18,7 @@ StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
 ComputationalResources = "0.3"
 Distributions = "0.22,0.23,0.24, 0.25"
 LatinHypercubeSampling = "1.7.2"
-MLJBase = "1"
+MLJBase = "1.3"
 ProgressMeter = "1.7.1"
 RecipesBase = "0.8,0.9,1"
 StatisticalMeasuresBase = "0.1.1"
diff --git a/src/tuned_models.jl b/src/tuned_models.jl
index ed236cd..2eb0c97 100644
--- a/src/tuned_models.jl
+++ b/src/tuned_models.jl
@@ -50,6 +50,7 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
     acceleration_resampling::AbstractResource
     check_measure::Bool
     cache::Bool
+    compact_history::Bool
 end
 
 mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Probabilistic
@@ -69,6 +70,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
     acceleration_resampling::AbstractResource
     check_measure::Bool
     cache::Bool
+    compact_history::Bool
 end
 
 const EitherTunedModel{T,M} =
@@ -176,6 +178,15 @@ key | value
 
 plus other key/value pairs specific to the `tuning` strategy.
 
+Each element of `history` is a property-accessible object with these properties:
+
+key                 | value
+--------------------|--------------------------------------------------
+`measure`           | vector of measures (metrics)
+`measurement`       | vector of measurements, one per measure
+`per_fold`          | vector of vectors of unaggregated per-fold measurements
+`evaluation`        | full `PerformanceEvaluation`/`CompactPerformanceEvaluation` object
+
 ### Complete list of key-word options
 
 - `model`: `Supervised` model prototype that is cloned and mutated to
@@ -240,27 +251,35 @@ plus other key/value pairs specific to the `tuning` strategy.
   user-suplied data; set to `false` to conserve memory. Speed gains
   likely limited to the case `resampling isa Holdout`.
 
+- `compact_history=true`: whether to write [`CompactPerformanceEvaluation`](@ref) or
+  regular [`PerformanceEvaluation`](@ref) objects to the history (accessed via the
+  `:evaluation` key); the compact form excludes some fields to conserve memory.
+
 """
-function TunedModel(args...; model=nothing,
-                    models=nothing,
-                    tuning=nothing,
-                    resampling=MLJBase.Holdout(),
-                    measures=nothing,
-                    measure=measures,
-                    weights=nothing,
-                    class_weights=nothing,
-                    operations=nothing,
-                    operation=operations,
-                    ranges=nothing,
-                    range=ranges,
-                    selection_heuristic=NaiveSelection(),
-                    train_best=true,
-                    repeats=1,
-                    n=nothing,
-                    acceleration=default_resource(),
-                    acceleration_resampling=CPU1(),
-                    check_measure=true,
-                    cache=true)
+function TunedModel(
+    args...;
+    model=nothing,
+    models=nothing,
+    tuning=nothing,
+    resampling=MLJBase.Holdout(),
+    measures=nothing,
+    measure=measures,
+    weights=nothing,
+    class_weights=nothing,
+    operations=nothing,
+    operation=operations,
+    ranges=nothing,
+    range=ranges,
+    selection_heuristic=NaiveSelection(),
+    train_best=true,
+    repeats=1,
+    n=nothing,
+    acceleration=default_resource(),
+    acceleration_resampling=CPU1(),
+    check_measure=true,
+    cache=true,
+    compact_history=true,
+    )
 
     # user can specify model as argument instead of kwarg:
     length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
@@ -339,7 +358,8 @@ function TunedModel(args...; model=nothing,
         acceleration,
         acceleration_resampling,
         check_measure,
-        cache
+        cache,
+        compact_history,
     )
 
     if M <: DeterministicTypes
@@ -582,9 +602,10 @@ function assemble_events!(metamodels,
                      check_measure = resampling_machine.model.check_measure,
                      repeats = resampling_machine.model.repeats,
                      acceleration = resampling_machine.model.acceleration,
-                     cache = resampling_machine.model.cache),
-                    resampling_machine.args...; cache=false) for
-                    _ in 2:length(partitions)]...]
+                     cache = resampling_machine.model.cache,
+                     compact = resampling_machine.model.compact
+                     ), resampling_machine.args...; cache=false) for
+                     _ in 2:length(partitions)]...]
 
     @sync for (i, parts) in enumerate(partitions)
         Threads.@spawn begin
@@ -736,21 +757,23 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
 
     # instantiate resampler (`model` to be replaced with mutated
    # clones during iteration below):
-    resampler = Resampler(model=model,
-                          resampling = deepcopy(tuned_model.resampling),
-                          measure = tuned_model.measure,
-                          weights = tuned_model.weights,
-                          class_weights = tuned_model.class_weights,
-                          operation = tuned_model.operation,
-                          check_measure = tuned_model.check_measure,
-                          repeats = tuned_model.repeats,
-                          acceleration = tuned_model.acceleration_resampling,
-                          cache = tuned_model.cache)
+    resampler = Resampler(
+        model=model,
+        resampling = deepcopy(tuned_model.resampling),
+        measure = tuned_model.measure,
+        weights = tuned_model.weights,
+        class_weights = tuned_model.class_weights,
+        operation = tuned_model.operation,
+        check_measure = tuned_model.check_measure,
+        repeats = tuned_model.repeats,
+        acceleration = tuned_model.acceleration_resampling,
+        cache = tuned_model.cache,
+        compact = tuned_model.compact_history,
+    )
 
     resampling_machine = machine(resampler, data...; cache=false)
     history, state = build!(nothing, n, tuning, model, model_buffer, state, verbosity,
                             acceleration, resampling_machine)
-
     return finalize(
         tuned_model,
         model_buffer,
@@ -867,9 +890,9 @@ function MLJBase.reports_feature_importances(model::EitherTunedModel)
 end
 # This is needed in some cases (e.g tuning a `Pipeline`)
 function MLJBase.feature_importances(::EitherTunedModel, fitresult, report)
-    # fitresult here is a machine created using the best_model obtained 
+    # fitresult here is a machine created using the best_model obtained
     # from the tuning process.
-    # The line below will return `nothing` when the model being tuned doesn't 
+    # The line below will return `nothing` when the model being tuned doesn't
     # support feature_importances.
     return MLJBase.feature_importances(fitresult)
 end
diff --git a/test/tuned_models.jl b/test/tuned_models.jl
index d56ecbc..9826dae 100644
--- a/test/tuned_models.jl
+++ b/test/tuned_models.jl
@@ -13,7 +13,7 @@ using Random
 Random.seed!(1234*myid())
 using .TestUtilities
 
-begin 
+begin
     N = 30
     x1 = rand(N);
     x2 = rand(N);
@@ -157,14 +157,14 @@ end
 
 @testset_accelerated "Feature Importances" accel begin
     # the DecisionTreeClassifier in /test/_models/ supports feature importances.
-    tm0 = TunedModel( 
-        model = trees[1], 
-        measure = rms, 
-        tuning = Grid(), 
-        resampling = CV(nfolds = 5), 
-        range = range( 
-            trees[1], :max_depth, values = 1:10 
-        ) 
+    tm0 = TunedModel(
+        model = trees[1],
+        measure = rms,
+        tuning = Grid(),
+        resampling = CV(nfolds = 5),
+        range = range(
+            trees[1], :max_depth, values = 1:10
+        )
     )
     @test reports_feature_importances(typeof(tm0))
     tm = TunedModel(
@@ -435,7 +435,7 @@ end
     model = DecisionTreeClassifier()
     tmodel = TunedModel(models=[model,])
     mach = machine(tmodel, X, y)
-    @test mach isa Machine{<:Any,false}
+    @test !MLJBase.caches_data(mach)
     fit!(mach, verbosity=-1)
     @test !isdefined(mach, :data)
     MLJBase.Tables.istable(mach.cache[end].fitresult.machine.data[1])
@@ -490,7 +490,7 @@ end
     @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2)
 end
 
-@testset_accelerated "full evaluation object" accel begin
+@testset_accelerated "evaluation object" accel begin
     X, y = make_regression(100, 2)
     dcr = DeterministicConstantRegressor()
 
@@ -504,10 +504,24 @@ end
     fit!(homach, verbosity=0);
     horep = report(homach)
     evaluations = getproperty.(horep.history, :evaluation)
+    @test first(evaluations) isa MLJBase.CompactPerformanceEvaluation
     measurements = getproperty.(evaluations, :measurement)
     models = getproperty.(evaluations, :model)
    @test all(==(measurements[1]), measurements)
     @test all(==(dcr), models)
+
+    homodel = TunedModel(
+        models=fill(dcr, 10),
+        resampling=Holdout(rng=StableRNG(1234)),
+        acceleration_resampling=accel,
+        measure=mae,
+        compact_history=false,
+    )
+    homach = machine(homodel, X, y)
+    fit!(homach, verbosity=0);
+    horep = report(homach)
+    evaluations = getproperty.(horep.history, :evaluation)
+    @test first(evaluations) isa MLJBase.PerformanceEvaluation
 end
 
 true
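Usage note (not part of the patch): a minimal sketch of the new `compact_history` option from a user's point of view. It assumes MLJ with MLJTuning >= 0.8.5 and MLJBase >= 1.3, and that MLJDecisionTreeInterface is installed so `DecisionTreeRegressor` can be loaded; the data, range, and variable names (`tuned`, `mach`) are placeholders, not part of the package.

```julia
using MLJ  # assumes MLJTuning >= 0.8.5 and MLJBase >= 1.3 in the environment

X, y = make_regression(100, 2)  # placeholder data

# placeholder model; assumes MLJDecisionTreeInterface is installed:
Tree = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0
tree = Tree()

tuned = TunedModel(
    model=tree,
    tuning=Grid(),
    resampling=CV(nfolds=3),
    range=range(tree, :max_depth, values=1:5),
    measure=rms,
    compact_history=false,  # store full `PerformanceEvaluation` objects in the history
)

mach = machine(tuned, X, y)
fit!(mach, verbosity=0)

# With `compact_history=false` each history entry's `evaluation` is a full
# `PerformanceEvaluation`; with the default `true` it is the leaner
# `CompactPerformanceEvaluation`, which conserves memory over long searches.
first(report(mach).history).evaluation
```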