
Merge pull request #172 from JuliaAI/class-weights
Add class weight support
ablaom authored May 24, 2022
2 parents 212bd5e + 36e444b commit 76b3718
Showing 2 changed files with 93 additions and 23 deletions.
72 changes: 49 additions & 23 deletions src/tuned_models.jl
@@ -30,7 +30,8 @@ mutable struct DeterministicTunedModel{T,M<:DeterministicTypes} <: MLJBase.Deter
tuning::T # tuning strategy
resampling # resampling strategy
measure
weights::Union{Nothing,Vector{<:Real}}
weights::Union{Nothing,AbstractVector{<:Real}}
class_weights::Union{Nothing,AbstractDict}
operation
range
selection_heuristic
@@ -49,6 +50,7 @@ mutable struct ProbabilisticTunedModel{T,M<:ProbabilisticTypes} <: MLJBase.Proba
resampling # resampling strategy
measure
weights::Union{Nothing,AbstractVector{<:Real}}
class_weights::Union{Nothing,AbstractDict}
operation
range
selection_heuristic
@@ -115,6 +117,8 @@ Calling `fit!(mach)` on a machine `mach=machine(tuned_model, X, y)` or
internal machine. The final train can be suppressed by setting
`train_best=false`.
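For orientation, here is a minimal end-to-end sketch. The `Grid` strategy
and the `DecisionTreeClassifier` from DecisionTree.jl are illustrative
assumptions only; any supported strategy and `Supervised` model will do:

```julia
using MLJ  # assumes MLJ (with MLJTuning) and DecisionTree.jl are installed

Tree = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
model = Tree()

# grid-search `max_depth`, estimating performance by 3-fold cross-validation:
r = range(model, :max_depth, lower=1, upper=6)
tuned_model = TunedModel(
    model=model,
    tuning=Grid(),
    resampling=CV(nfolds=3),
    range=r,
    measure=MisclassificationRate(),
    train_best=true,  # the default: retrain the best model on all data
)

X, y = make_blobs()               # synthetic classification data
mach = machine(tuned_model, X, y)
fit!(mach, verbosity=0)           # runs the search, then trains the best model
predict(mach, X)                  # predictions come from the best model
```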
### Search space
The `range` objects supported depend on the `tuning` strategy
specified. Query the `strategy` docstring for details. To optimize
over an explicit list `v` of models of the same type, use
@@ -125,28 +129,26 @@ then `MLJTuning.default_n(tuning, range)` is used. When `n` is
increased and `fit!(mach)` called again, the old search history is
re-instated and the search continues where it left off.
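For example, continuing the sketch above:

```julia
tuned_model.n = 25       # raise the total number of models to evaluate
fit!(mach, verbosity=0)  # history re-instated; only the new models are evaluated
```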
If `measure` supports weights (`supports_weights(measure) == true`)
then any `weights` specified will be passed to the measure. If more
than one `measure` is specified, then only the first is optimized
(unless `strategy` is multi-objective) but the performance against
every measure specified will be computed and reported in
`report(mach).best_performance` and other relevant attributes of the
generated report.
### Measures (metrics)
Specify `repeats > 1` for repeated resampling per model
evaluation. See [`evaluate!`](@ref) options for details.
If more than one `measure` is specified, then only the first is
optimized (unless `strategy` is multi-objective) but the performance
against every measure specified will be computed and reported in
`report(mach).best_performance` and other relevant attributes of the
generated report. Options exist to pass per-observation weights or
class weights to measures; see below.
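Here is a sketch with hypothetical weights, re-using `model` and `r` from
the sketch above (the `class_weights` keys must match the levels of the
training target, and `weights` must have one entry per observation):

```julia
w = abs.(randn(100))                      # hypothetical per-observation weights
class_w = Dict("a" => 2.0, "b" => 100.0)  # hypothetical class => weight map

tuned_model = TunedModel(
    model=model,
    range=r,
    measures=[MisclassificationRate(), MulticlassFScore()],  # first is optimized
    weights=w,
    class_weights=class_w,
)
```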
*Important.* If a custom `measure` is used, and the measure is
a score, rather than a loss, be sure to check that
`MLJ.orientation(measure) == :score` to ensure maximization of the
*Important.* If a custom measure `my_measure` is used, and the measure
is a score rather than a loss, be sure to check that
`MLJ.orientation(my_measure) == :score` to ensure maximization of the
measure, rather than minimization. Override an incorrect value with
`MLJ.orientation(::typeof(measure)) = :score`.
`MLJ.orientation(::typeof(my_measure)) = :score`.
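For example, here is a sketch for a hypothetical function-based score:

```julia
my_measure(ŷ, y) = sum(ŷ .== y) / length(y)  # bigger is better, so a score

MLJ.orientation(my_measure)  # check this returns `:score` and not `:loss`

# if it does not, override, as prescribed above:
MLJ.orientation(::typeof(my_measure)) = :score
```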
### Accessing the fitted parameters and other training (tuning) outcomes
A Plots.jl plot of performance estimates is returned by `plot(mach)`
or `heatmap(mach)`.
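For example, assuming Plots.jl is installed:

```julia
using Plots
plot(mach)  # performance estimates for the models evaluated in the search
```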
### Accessing the fitted parameters and other training (tuning) outcomes
Once a tuning machine `mach` has been trained as above, then
`fitted_params(mach)` has these keys/values:
@@ -166,7 +168,7 @@ key | value
plus other key/value pairs specific to the `tuning` strategy.
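For example, a sketch, with `mach` the trained tuning machine from above:

```julia
fitted_params(mach).best_model   # optimal model instance found in the search
report(mach).best_history_entry  # its recorded evaluation in the history
```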
### Summary of keyword arguments
### Complete list of keyword options
- `model`: `Supervised` model prototype that is cloned and mutated to
generate models for evaluation
@@ -186,11 +188,15 @@ plus other key/value pairs specific to the `tuning` strategy.
evaluations; only the first used in optimization (unless the
strategy is multi-objective) but all reported to the history
- `weights`: sample weights to be passed to the measure(s) in performance
evaluations, if supported.
- `weights`: per-observation weights to be passed to the measure(s) in performance
evaluations, where supported. Check support with `supports_weights(measure)`.
- `class_weights`: class weights to be passed to the measure(s) in
performance evaluations, where supported. Check support with
`supports_class_weights(measure)`.
- `repeats=1`: for generating train/test sets multiple times in
resampling; see [`evaluate!`](@ref) for details
resampling ("Monte Carlo" resampling); see [`evaluate!`](@ref) for details
- `operation`/`operations` - One of
$(MLJBase.PREDICT_OPERATIONS_STRING), or a vector of these of the
@@ -234,6 +240,7 @@ function TunedModel(; model=nothing,
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operations=nothing,
operation=operations,
ranges=nothing,
@@ -296,9 +303,24 @@
# get the tuning type parameter:
T = typeof(tuning)

args = (model, tuning, resampling, measure, weights, operation, range,
selection_heuristic, train_best, repeats, n, acceleration, acceleration_resampling,
check_measure, cache)
args = (
model,
tuning,
resampling,
measure,
weights,
class_weights,
operation,
range,
selection_heuristic,
train_best,
repeats,
n,
acceleration,
acceleration_resampling,
check_measure,
cache
)

if M <: DeterministicTypes
tuned_model = DeterministicTunedModel{T,M}(args...)
@@ -532,6 +554,7 @@ function assemble_events!(metamodels,
resampling = resampling_machine.model.resampling,
measure = resampling_machine.model.measure,
weights = resampling_machine.model.weights,
class_weights = resampling_machine.model.class_weights,
operation = resampling_machine.model.operation,
check_measure = resampling_machine.model.check_measure,
repeats = resampling_machine.model.repeats,
@@ -694,6 +717,7 @@ function MLJBase.fit(tuned_model::EitherTunedModel{T,M},
resampling = deepcopy(tuned_model.resampling),
measure = tuned_model.measure,
weights = tuned_model.weights,
class_weights = tuned_model.class_weights,
operation = tuned_model.operation,
check_measure = tuned_model.check_measure,
repeats = tuned_model.repeats,
@@ -785,6 +809,8 @@ end
MLJBase.is_wrapper(::Type{<:EitherTunedModel}) = true
MLJBase.supports_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
MLJBase.supports_weights(M)
MLJBase.supports_class_weights(::Type{<:EitherTunedModel{<:Any,M}}) where M =
MLJBase.supports_class_weights(M)
MLJBase.load_path(::Type{<:ProbabilisticTunedModel}) =
"MLJTuning.ProbabilisticTunedModel"
MLJBase.load_path(::Type{<:DeterministicTunedModel}) =
44 changes: 44 additions & 0 deletions test/tuned_models.jl
@@ -341,6 +341,50 @@ end

end

@testset_accelerated "weights and class_weights are being passed" accel begin
# we'll be tuning using 50/50 holdout
X = (x=fill(1.0, 6),)
y = coerce(["a", "a", "b", "a", "a", "b"], OrderedFactor)
w = [1.0, 1.0, 100.0, 1.0, 1.0, 100.0]
class_w = Dict("a" => 2.0, "b" => 100.0)

model = DecisionTreeClassifier()

# the first supports weights, the second class weights:
ms = [MisclassificationRate(), MulticlassFScore()]

resampling = Holdout(fraction_train=0.5)

# without weights:
tmodel = TunedModel(
resampling=resampling,
models=fill(model, 5),
measures=ms,
acceleration=accel
)
mach = machine(tmodel, X, y)
fit!(mach, verbosity=0)
measurement = report(mach).best_history_entry.measurement
e = evaluate(model, X, y, measures=ms, resampling=resampling)
@test measurement == e.measurement

# with weights:
tmodel.weights = w
tmodel.class_weights = class_w
fit!(mach, verbosity=0)
measurement_weighted = report(mach).best_history_entry.measurement
e_weighted = evaluate(model, X, y;
measures=ms,
resampling=resampling,
weights=w,
class_weights=class_w,
verbosity=-1)
@test measurement_weighted == e_weighted.measurement

# check both measures are different when they are weighted:
@test !any(measurement .== measurement_weighted)
end

@testset "data caching at outer level suppressed" begin
X, y = make_blobs()
model = DecisionTreeClassifier()

