diff --git a/src/tuned_models.jl b/src/tuned_models.jl index 1b4936b..853c74d 100644 --- a/src/tuned_models.jl +++ b/src/tuned_models.jl @@ -21,6 +21,10 @@ const ERR_MODEL_TYPE = ArgumentError( "Only `Deterministic` and `Probabilistic` model types supported.") const INFO_MODEL_IGNORED = "`model` being ignored. Using `model=first(range)`. " +const ERR_TOO_MANY_ARGUMENTS = + ArgumentError("At most one non-keyword argument allowed. ") +warn_double_spec(arg, model) = + "Using `model=$arg`. Ignoring keyword specification `model=$model`. " const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector} const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector} @@ -233,7 +237,7 @@ plus other key/value pairs specific to the `tuning` strategy. likely limited to the case `resampling isa Holdout`. """ -function TunedModel(; model=nothing, +function TunedModel(args...; model=nothing, models=nothing, tuning=nothing, resampling=MLJBase.Holdout(), @@ -254,8 +258,17 @@ function TunedModel(; model=nothing, check_measure=true, cache=true) + # user can specify model as argument instead of kwarg: + length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS) + if length(args) === 1 + arg = first(args) + model === nothing || + @warn warn_double_spec(arg, model) + model =arg + end + # either `models` is specified and `tuning` is set to `Explicit`, - # or `models` is unspecified and tuning will fallback to `Grid()` + # or `models` is unspecified and tuning will fallback to `RandomSearch()` # unless it is itself specified: if models !== nothing if tuning === nothing diff --git a/test/tuned_models.jl b/test/tuned_models.jl index 24563da..c512d9a 100644 --- a/test/tuned_models.jl +++ b/test/tuned_models.jl @@ -35,7 +35,7 @@ r = [m(K) for K in 13:-1:2] @test_throws(MLJTuning.ERR_BOTH_DISALLOWED, TunedModel(model=first(r), models=r, tuning=Explicit(), measure=rms)) - tm = TunedModel(models=r, tuning=Explicit(), measure=rms) + tm = @test_logs TunedModel(models=r, tuning=Explicit(), measure=rms) @test tm.tuning isa Explicit && tm.range ==r && tm.model == first(r) @test input_scitype(tm) == Unknown @test TunedModel(models=r, measure=rms) == tm @@ -54,7 +54,16 @@ r = [m(K) for K in 13:-1:2] TunedModel(tuning=Explicit(), measure=rms)) @test_throws(MLJTuning.ERR_NEED_EXPLICIT, TunedModel(models=r, tuning=Grid())) - tm = TunedModel(model=first(r), range=r, measure=rms) + @test_logs TunedModel(first(r), range=r, measure=rms) + @test_logs( + (:warn, MLJTuning.warn_double_spec(first(r), last(r))), + TunedModel(first(r), model=last(r), range=r, measure=rms), + ) + @test_throws( + MLJTuning.ERR_TOO_MANY_ARGUMENTS, + TunedModel(first(r), last(r), range=r, measure=rms), + ) + tm = @test_logs TunedModel(model=first(r), range=r, measure=rms) @test tm.tuning isa RandomSearch @test input_scitype(tm) == Table(Continuous) end