From 104705b35ec1cd5a8392d2a396c315af5c1e7978 Mon Sep 17 00:00:00 2001 From: OkonSamuel Date: Wed, 4 Sep 2024 22:24:51 +0100 Subject: [PATCH 1/2] add feature importances support to iterated models --- src/core.jl | 8 ++++++++ src/traits.jl | 1 + test/core.jl | 13 +++++++++++-- test/traits.jl | 1 + 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/core.jl b/src/core.jl index 6708ab0..24ae536 100644 --- a/src/core.jl +++ b/src/core.jl @@ -155,3 +155,11 @@ MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = # here `fitresult` is a trained atomic machine: MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) + +# Feature importances +function MLJBase.feature_importances(::EitherIteratedModel, fitresult, report) + # fitresult here is the current state of the iterated machine + # The line below will return `nothing` when the iteration model doesn't + # support feature_importances. 
+ return MLJBase.feature_importances(fitresult) +end \ No newline at end of file diff --git a/src/traits.jl b/src/traits.jl index 6380d41..216a88a 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -15,6 +15,7 @@ for trait in [:supports_weights, :is_pure_julia, :input_scitype, :output_scitype, + :reports_feature_importances, :target_scitype] quote # needed because traits are not always deducable from diff --git a/test/core.jl b/test/core.jl index 1f54945..00311b4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -272,8 +272,10 @@ function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) # if I serialize/deserialized `thing` then `id` below changes: id = objectid(thing) fitresult = (thing, id, mean(y)) - return fitresult, nothing, NamedTuple() + report = (importances = [ftr => 1.0 for ftr in MLJBase.schema(X).names], ) + return fitresult, nothing, report end + function MLJBase.predict(::EphemeralRegressor, fitresult, X) thing, id, μ = fitresult return id == objectid(thing) ? fill(μ, nrows(X)) : @@ -290,7 +292,12 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) return (thing, id, μ) end -@testset "save and restore" begin +MLJBase.reports_feature_importances(::Type{<:EphemeralRegressor}) = true +function MLJBase.feature_importances(::EphemeralRegressor, fitresult, report) + return report.importances +end + +@testset "feature importances, save and restore" begin #https://github.com/JuliaAI/MLJ.jl/issues/1099 X, y = (; x = rand(10)), fill(42.0, 3) controls = [Step(1), NumberLimit(2)] @@ -302,12 +309,14 @@ end ) mach = machine(imodel, X, y) fit!(mach, verbosity=0) + @test MLJBase.feature_importances(mach) == [:x => 1.0]; io = IOBuffer() MLJBase.save(io, mach) seekstart(io) mach2 = machine(io) close(io) @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) + end end diff --git a/test/traits.jl b/test/traits.jl index 341628b..b132f98 100644 --- a/test/traits.jl +++ b/test/traits.jl @@ -23,6 +23,7 @@ imodel = IteratedModel(model=model, 
measure=mae) @test output_scitype(imodel) == output_scitype(model) @test target_scitype(imodel) == target_scitype(model) @test constructor(imodel) == IteratedModel +@test reports_feature_importances(imodel) == reports_feature_importances(model) end From 8d776fb1a2f4a66ac07c61e1c8d5096ffc32d11e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 6 Sep 2024 08:46:22 +1200 Subject: [PATCH 2/2] bump 0.6.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b477951..86200d1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.6.2" +version = "0.6.3" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"