Skip to content

Commit

Permalink
add test to catch serialization bug
Browse files Browse the repository at this point in the history
oops
  • Loading branch information
ablaom committed Mar 3, 2024
1 parent b12c0bc commit 3f070fb
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,54 @@ end
@test iteration_parameter(model) == :n
end

# define a supervised model with ephemeral `fitresult`, but which overcomes this by
# overloading `save`/`restore`:
thing = []
mutable struct EphemeralRegressor <: Deterministic
n::Int # dummy iteration parameter
end
EphemeralRegressor(; n=1) = EphemeralRegressor(n)
function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
# if I serialize/deserialized `thing` then `view` below changes:
view = objectid(thing)
fitresult = (thing, view, mean(y))
return fitresult, nothing, NamedTuple()
end
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
thing, view, μ = fitresult
return view == objectid(thing) ? fill(μ, nrows(X)) :
throw(ErrorException("dead fitresult"))
end
MLJBase.iteration_parameter(::EphemeralRegressor) = :n
function MLJBase.save(::EphemeralRegressor, fitresult)
thing, _, μ = fitresult
return (thing, μ)
end
function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
thing, μ = serialized_fitresult
view = objectid(thing)
return (thing, view, μ)
end

@testset "save and restore" begin
#https://github.com/alan-turing-institute/MLJ.jl/issues/1099
X, y = (; x = rand(10)), fill(42.0, 3)
controls = [Step(1), NumberLimit(2)]
imodel = IteratedModel(
EphemeralRegressor(42);
measure=l2,
resampling=Holdout(),
controls,
)
mach = machine(imodel, X, y)
fit!(mach, verbosity=0)
io = IOBuffer()
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
close(io)
@test_broken MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)
end

end
true

0 comments on commit 3f070fb

Please sign in to comment.