From 3f070fbea5696da23a6e5592339bab3a94079b91 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 12:21:15 +1300 Subject: [PATCH 1/4] add test to catch serialization bug oops --- test/core.jl | 49 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/test/core.jl b/test/core.jl index 4e75ead..566b5ac 100644 --- a/test/core.jl +++ b/test/core.jl @@ -239,5 +239,54 @@ end @test iteration_parameter(model) == :n end +# define a supervised model with ephemeral `fitresult`, but which overcomes this by +# overloading `save`/`restore`: +thing = [] +mutable struct EphemeralRegressor <: Deterministic + n::Int # dummy iteration parameter +end +EphemeralRegressor(; n=1) = EphemeralRegressor(n) +function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) + # if I serialize/deserialized `thing` then `view` below changes: + view = objectid(thing) + fitresult = (thing, view, mean(y)) + return fitresult, nothing, NamedTuple() +end +function MLJBase.predict(::EphemeralRegressor, fitresult, X) + thing, view, μ = fitresult + return view == objectid(thing) ? fill(μ, nrows(X)) : + throw(ErrorException("dead fitresult")) +end +MLJBase.iteration_parameter(::EphemeralRegressor) = :n +function MLJBase.save(::EphemeralRegressor, fitresult) + thing, _, μ = fitresult + return (thing, μ) +end +function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) + thing, μ = serialized_fitresult + view = objectid(thing) + return (thing, view, μ) +end + +@testset "save and restore" begin + #https://github.com/alan-turing-institute/MLJ.jl/issues/1099 + X, y = (; x = rand(10)), fill(42.0, 3) + controls = [Step(1), NumberLimit(2)] + imodel = IteratedModel( + EphemeralRegressor(42); + measure=l2, + resampling=Holdout(), + controls, + ) + mach = machine(imodel, X, y) + fit!(mach, verbosity=0) + io = IOBuffer() + MLJBase.save(io, mach) + seekstart(io) + mach2 = machine(io) + close(io) + @test_broken MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) +end + end true From e901f4a33fef286caa810eed1384b70752011fd4 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 12:31:02 +1300 Subject: [PATCH 2/4] fix serialization, resolving part of MLJ.jl issue 1099 --- src/core.jl | 4 ++++ test/core.jl | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/core.jl b/src/core.jl index df5f88b..6708ab0 100644 --- a/src/core.jl +++ b/src/core.jl @@ -151,3 +151,7 @@ MLJBase.predict(::EitherIteratedModel, fitresult, Xnew) = MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) = transform(fitresult, Xnew) + +# here `fitresult` is a trained atomic machine: +MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult) +MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult) diff --git a/test/core.jl b/test/core.jl index 566b5ac..f7dccde 100644 --- a/test/core.jl +++ b/test/core.jl @@ -285,7 +285,7 @@ end seekstart(io) mach2 = machine(io) close(io) - @test_broken MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) end end From 4e324106e49da4d9d00bfa5b1f1b8c475c31ebf2 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Mon, 4 Mar 2024 12:48:48 +1300 Subject: [PATCH 3/4] bump 0.6.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 797cac5..2234904 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.6.0" +version = "0.6.1" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" From 4db6d5930a3ac30f7db6843694082d15666bbf81 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 7 Mar 2024 08:08:08 +1300 Subject: [PATCH 4/4] view -> id --- test/core.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/core.jl b/test/core.jl index f7dccde..2119da4 100644 --- a/test/core.jl +++ b/test/core.jl @@ -247,14 +247,14 @@ mutable struct EphemeralRegressor <: Deterministic end EphemeralRegressor(; n=1) = EphemeralRegressor(n) function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) - # if I serialize/deserialized `thing` then `view` below changes: - view = objectid(thing) - fitresult = (thing, view, mean(y)) + # if I serialize/deserialized `thing` then `id` below changes: + id = objectid(thing) + fitresult = (thing, id, mean(y)) return fitresult, nothing, NamedTuple() end function MLJBase.predict(::EphemeralRegressor, fitresult, X) - thing, view, μ = fitresult - return view == objectid(thing) ? fill(μ, nrows(X)) : + thing, id, μ = fitresult + return id == objectid(thing) ? fill(μ, nrows(X)) : throw(ErrorException("dead fitresult")) end MLJBase.iteration_parameter(::EphemeralRegressor) = :n @@ -264,8 +264,8 @@ function MLJBase.save(::EphemeralRegressor, fitresult) end function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) thing, μ = serialized_fitresult - view = objectid(thing) - return (thing, view, μ) + id = objectid(thing) + return (thing, id, μ) end @testset "save and restore" begin