Skip to content

Commit

Permalink
Merge pull request #77 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.5 release
  • Loading branch information
ablaom authored Sep 15, 2020
2 parents a18a191 + 44b33cc commit 3e989d5
Show file tree
Hide file tree
Showing 16 changed files with 602 additions and 426 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTuning"
uuid = "03970b2e-30c4-11ea-3135-d1576263f10f"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.4.3"
version = "0.5.0"

[deps]
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Expand Down
497 changes: 281 additions & 216 deletions README.md

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/MLJTuning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export TunedModel
# defined in strategies/:
export Explicit, Grid, RandomSearch

# defined in selection_heuristics/:
export NaiveSelection

# defined in learning_curves.jl:
export learning_curve!, learning_curve

Expand Down Expand Up @@ -36,6 +39,9 @@ const DEFAULT_N = 10 # for when `default_n` is not implemented

include("utilities.jl")
include("tuning_strategy_interface.jl")
include("selection_heuristics.jl")
include("tuned_models.jl")
include("range_methods.jl")
include("strategies/explicit.jl")
include("strategies/grid.jl")
include("strategies/random_search.jl")
Expand Down
61 changes: 61 additions & 0 deletions src/selection_heuristics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Supertype of all selection heuristics used by `TunedModel` instances to
# single out the "best" entry of a tuning history.
abstract type SelectionHeuristic end

## HELPERS

"""
    measure_adjusted_weights(weights, measures)

Return a vector of per-measure weights, each multiplied by the
corresponding measure's `signature` (which orients scores vs losses;
`signature` is defined elsewhere in the package). If `weights ===
nothing`, all weight is placed on the first measure. Throw
`DimensionMismatch` when `weights` and `measures` have different
lengths.
"""
function measure_adjusted_weights(weights, measures)
    if weights === nothing
        # default: only the first measure counts
        return vcat([signature(measures[1]), ], zeros(length(measures) - 1))
    end
    length(weights) == length(measures) ||
        throw(DimensionMismatch(
            "`NaiveSelection` heuristic "*
            "is being applied to a list of measures whose length "*
            "differs from that of the specified `weights`. "))
    return signature.(measures) .* weights
end


## OPTIMIZE AGGREGATED MEASURE

"""
    NaiveSelection(; weights=nothing)

Construct a common selection heuristic for use with `TunedModel`
instances which only considers measurements aggregated over all
samples (folds) in resampling.

For each entry in the tuning history, one defines a penalty equal to
the evaluations of the `measure` specified in the `TunedModel`
instance, aggregated over all samples, and multiplied by `-1` if
`measure` is a `:score`, and `+1` if it is a loss. The heuristic
declares as "best" (optimal) the model whose corresponding entry has
the lowest penalty.

If `measure` is a vector, then the first element is used, unless
per-measure `weights` are explicitly specified. Weights associated
with measures that are neither `:loss` nor `:score` are reset to zero.
"""
struct NaiveSelection <: SelectionHeuristic
    # `nothing` means "use only the first measure"; otherwise one weight per
    # measure. `Vector{<:Real}` (rather than the invariant `Vector{Real}`)
    # lets concrete vectors such as `Vector{Float64}` be stored directly,
    # without element boxing or conversion.
    weights::Union{Nothing, Vector{<:Real}}
end
NaiveSelection(; weights=nothing) = NaiveSelection(weights)


# Return the history entry with the lowest weight-adjusted penalty, where
# penalties are the weighted, aggregated measurements of each entry (signs
# flipped for `:score`-oriented measures so a single `argmin` suffices).
function best(heuristic::NaiveSelection, history)
    measures = history[1].measure
    w = measure_adjusted_weights(heuristic.weights, measures)
    penalties = [w' * entry.measurement for entry in history]
    # scores are to be maximized, so negate to keep minimizing:
    if orientation(measures[1]) == :score
        penalties = -penalties
    end
    return history[argmin(penalties)]
end

# `NaiveSelection` is compatible with every tuning strategy.
function MLJTuning.supports_heuristic(::Any, ::NaiveSelection)
    return true
end
10 changes: 5 additions & 5 deletions src/strategies/explicit.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Singleton tuning strategy that iterates over an explicitly given range
# (a model-generating iterator); see `ExplicitState` and `models` below.
mutable struct Explicit <: TuningStrategy end

mutable struct ExplicitState{R,N}
# Immutable state for the `Explicit` strategy: the model-generating
# iterator plus the most recent `iterate` output (`nothing` once the
# iterator is exhausted).
struct ExplicitState{R,N}
    range::R # a model-generating iterator
    next::Union{Nothing,N} # to hold output of `iterate(range)`
end
Expand All @@ -14,7 +14,7 @@ function MLJTuning.setup(tuning::Explicit, model, range, verbosity)
end

# models! returns all available models in the range at once:
function MLJTuning.models!(tuning::Explicit,
function MLJTuning.models(tuning::Explicit,
model,
history,
state,
Expand All @@ -23,7 +23,7 @@ function MLJTuning.models!(tuning::Explicit,

range, next = state.range, state.next

next === nothing && return nothing
next === nothing && return nothing, state

m, s = next
models = [m, ]
Expand All @@ -39,9 +39,9 @@ function MLJTuning.models!(tuning::Explicit,
next = iterate(range, s)
end

state.next = next
new_state = ExplicitState(range, next)

return models
return models, new_state

end

Expand Down
26 changes: 10 additions & 16 deletions src/strategies/grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,16 @@ function setup(tuning::Grid, model, user_range, verbosity)

end

MLJTuning.models!(tuning::Grid,
model,
history,
state,
n_remaining,
verbosity) =
state.models[_length(history) + 1:end]

function tuning_report(tuning::Grid, history, state)

plotting = plotting_report(state.fields, state.parameter_scales, history)

# todo: remove collects?
return (history=history, plotting=plotting)

end
# Grid search: all candidate models were generated in `setup` and stored in
# `state.models`; return the not-yet-evaluated tail of that list, together
# with the (unchanged) state.
MLJTuning.models(tuning::Grid,
                 model,
                 history,
                 state,
                 n_remaining,
                 verbosity) =
    state.models[_length(history) + 1:end], state

# Strategy-specific report for `Grid`: a named tuple holding data for
# hyperparameter plots, built from the field names and scales captured in
# `state`.
tuning_report(tuning::Grid, history, state) =
    (plotting = plotting_report(state.fields, state.parameter_scales, history),)

function default_n(tuning::Grid, user_range)

Expand Down
17 changes: 9 additions & 8 deletions src/strategies/random_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,21 @@ setup(tuning::RandomSearch, model, user_range, verbosity) =
tuning.positive_unbounded,
tuning.other) |> collect

function MLJTuning.models!(tuning::RandomSearch,
model,
history,
state, # tuple of (field, sampler) pairs
n_remaining,
verbosity)
return map(1:n_remaining) do _
# Random search: produce `n_remaining` fresh candidate models by cloning
# `model` and assigning each hyperparameter field a value drawn from its
# sampler, using the strategy's own RNG.
function MLJTuning.models(tuning::RandomSearch,
                          model,
                          history,
                          state, # tuple of (field, sampler) pairs
                          n_remaining,
                          verbosity)
    new_models = map(1:n_remaining) do _
        clone = deepcopy(model)
        # NOTE(review): `shuffle!` mutates `state` in place on every draw —
        # presumably to randomize the order fields are set in; confirm intent.
        Random.shuffle!(tuning.rng, state)
        for (fld, s) in state
            recursive_setproperty!(clone, fld, rand(tuning.rng, s))
        end
        clone
    end
    # state is returned unchanged (apart from the in-place shuffles above)
    return new_models, state
end

function tuning_report(tuning::RandomSearch, history, field_sampler_pairs)
Expand All @@ -141,6 +142,6 @@ function tuning_report(tuning::RandomSearch, history, field_sampler_pairs)

plotting = plotting_report(fields, parameter_scales, history)

return (history=history, plotting=plotting)
return (plotting=plotting,)

end
Loading

0 comments on commit 3e989d5

Please sign in to comment.