Skip to content

Commit

Permalink
Merge pull request #175 from JuliaAI/model-as-arg
Browse files Browse the repository at this point in the history
Allow user to specify model in `TunedModel` as arg instead of kwarg
  • Loading branch information
ablaom authored May 24, 2022
2 parents 76b3718 + 0174e93 commit 7a7db1b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
17 changes: 15 additions & 2 deletions src/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ const ERR_MODEL_TYPE = ArgumentError(
"Only `Deterministic` and `Probabilistic` model types supported.")
const INFO_MODEL_IGNORED =
"`model` being ignored. Using `model=first(range)`. "
const ERR_TOO_MANY_ARGUMENTS =
ArgumentError("At most one non-keyword argument allowed. ")
warn_double_spec(arg, model) =
"Using `model=$arg`. Ignoring keyword specification `model=$model`. "

const ProbabilisticTypes = Union{Probabilistic, MLJBase.MLJModelInterface.ProbabilisticDetector}
const DeterministicTypes = Union{Deterministic, MLJBase.MLJModelInterface.DeterministicDetector}
Expand Down Expand Up @@ -233,7 +237,7 @@ plus other key/value pairs specific to the `tuning` strategy.
likely limited to the case `resampling isa Holdout`.
"""
function TunedModel(; model=nothing,
function TunedModel(args...; model=nothing,
models=nothing,
tuning=nothing,
resampling=MLJBase.Holdout(),
Expand All @@ -254,8 +258,17 @@ function TunedModel(; model=nothing,
check_measure=true,
cache=true)

# user can specify model as argument instead of kwarg:
length(args) < 2 || throw(ERR_TOO_MANY_ARGUMENTS)
if length(args) === 1
arg = first(args)
model === nothing ||
@warn warn_double_spec(arg, model)
model =arg
end

# either `models` is specified and `tuning` is set to `Explicit`,
# or `models` is unspecified and tuning will fallback to `Grid()`
# or `models` is unspecified and tuning will fallback to `RandomSearch()`
# unless it is itself specified:
if models !== nothing
if tuning === nothing
Expand Down
13 changes: 11 additions & 2 deletions test/tuned_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ r = [m(K) for K in 13:-1:2]
@test_throws(MLJTuning.ERR_BOTH_DISALLOWED,
TunedModel(model=first(r),
models=r, tuning=Explicit(), measure=rms))
tm = TunedModel(models=r, tuning=Explicit(), measure=rms)
tm = @test_logs TunedModel(models=r, tuning=Explicit(), measure=rms)
@test tm.tuning isa Explicit && tm.range ==r && tm.model == first(r)
@test input_scitype(tm) == Unknown
@test TunedModel(models=r, measure=rms) == tm
Expand All @@ -54,7 +54,16 @@ r = [m(K) for K in 13:-1:2]
TunedModel(tuning=Explicit(), measure=rms))
@test_throws(MLJTuning.ERR_NEED_EXPLICIT,
TunedModel(models=r, tuning=Grid()))
tm = TunedModel(model=first(r), range=r, measure=rms)
@test_logs TunedModel(first(r), range=r, measure=rms)
@test_logs(
(:warn, MLJTuning.warn_double_spec(first(r), last(r))),
TunedModel(first(r), model=last(r), range=r, measure=rms),
)
@test_throws(
MLJTuning.ERR_TOO_MANY_ARGUMENTS,
TunedModel(first(r), last(r), range=r, measure=rms),
)
tm = @test_logs TunedModel(model=first(r), range=r, measure=rms)
@test tm.tuning isa RandomSearch
@test input_scitype(tm) == Table(Continuous)
end
Expand Down

0 comments on commit 7a7db1b

Please sign in to comment.