Skip to content

Commit

Permalink
Merge pull request #960 from JuliaAI/nested-serialization
Browse files Browse the repository at this point in the history
Fix problem with serialization of nested models when component model overload `save`/`restore`
  • Loading branch information
ablaom authored Mar 1, 2024
2 parents 831abfa + f7ef4fe commit b23bbd6
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.1.1"
version = "1.1.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
24 changes: 14 additions & 10 deletions src/composition/learning_networks/replace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ the `model` and `args` field values as derived from the provided dictionaries. I
the returned machine is hooked into the new learning network defined by the values of
`newnode_given_old`.
If `serializable=true`, return a serializable copy instead (namely,
`serializable(node.mach)`) and ignore the `newmodel_given_old` dictionary (no model
replacement).
If `serializable=true`, return a serializable copy instead, but make no model replacement.
The `newmodel_given_old` dictionary is still used, but now to look up the concrete model
corresponding to the symbolic one stored in `node`'s machine.
See also [`serializable`](@ref).
Expand All @@ -26,9 +26,10 @@ function machine_replacement(
newnode_given_old,
serializable
)
# the `replace` called here is defined in src/machines.jl:
mach = serializable ? MLJBase.serializable(N.machine) :
replace(N.machine, :model => newmodel_given_old[N.machine.model])
# the `replace` called below is defined in src/machines.jl.
newmodel = newmodel_given_old[N.machine.model]
mach = serializable ? MLJBase.serializable(N.machine, newmodel) :
replace(N.machine, :model => newmodel)
mach.args = Tuple(newnode_given_old[arg] for arg in N.machine.args)
return mach
end
Expand All @@ -38,6 +39,7 @@ end
newnode_given_old,
newmach_given_old,
newmodel_given_old,
serializable,
node::AbstractNode)
**Private method.**
Expand Down Expand Up @@ -86,9 +88,11 @@ const DOC_REPLACE_OPTIONS =
- `copy_unspecified_deeply=true`: If `false`, models or sources not listed for
replacement are identically equal in the original and returned node.
- `serializable=false`: If `true`, all machines in the new network are serializable.
However, all `model` replacements are ignored, and unspecified sources are always
replaced with empty ones.
- `serializable=false`: If `true`, all machines in the new network are made
serializable and the specified model replacements are only used for serialization
purposes: for each pair `s => model` (`s` assumed to be a symbolic model) each
machine with model `s` is replaced with `serializable(mach, model)`. All unspecified
sources are always replaced with empty ones.
"""

Expand Down Expand Up @@ -192,7 +196,7 @@ function _replace(

# Instantiate model dictionary:
model_pairs = filter(collect(pairs)) do pair
first(pair) isa Model
first(pair) isa Model || first(pair) isa Symbol
end
models_ = models(W)
models_to_copy = setdiff(models_, first.(model_pairs))
Expand Down
25 changes: 20 additions & 5 deletions src/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,33 @@ MLJModelInterface.fitted_params(composite::NetworkComposite, signature) =
MLJModelInterface.reporting_operations(::Type{<:NetworkComposite}) = OPERATIONS

# here `fitresult` has type `Signature`.
save(model::NetworkComposite, fitresult) = replace(fitresult, serializable=true)
function save(model::NetworkComposite, fitresult)
# The network includes machines with symbolic models. These machines need to be
# replaced by serializable versions, but we cannot naively use `serializable(mach)`,
# because the absence of the concrete model means this just returns `mach` (because
# `save(::Symbol, fitresult)` returns `fitresult`). We need to use the special
# `serialiable(mach, model)` instead. This is what `replace` below does, because we
# pass it the flag `serializable=true` but we must also pass `symbol =>
# concrete_model` replacements, which we calculate first:

greatest_lower_bound = MLJBase.glb(fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
atomic_models = keys(machines_given_model)
pairs = [atom => getproperty(model, atom) for atom in atomic_models]

replace(fitresult, pairs...; serializable=true)
end

function MLJModelInterface.restore(model::NetworkComposite, serializable_fitresult)
greatest_lower_bound = MLJBase.glb(serializable_fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
models = keys(machines_given_model)
atomic_models = keys(machines_given_model)

# the following indirectly mutates `serialiable_fiteresult`, returning it to
# usefulness:
for model in models
for mach in machines_given_model[model]
mach.fitresult = restore(model, mach.fitresult)
for atom in atomic_models
for mach in machines_given_model[atom]
mach.fitresult = MLJBase.restore(getproperty(model, atom), mach.fitresult)
mach.state = 1
end
end
Expand Down
14 changes: 7 additions & 7 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,17 +975,17 @@ A machine returned by `serializable` is characterized by the property
See also [`restore!`](@ref), [`MLJBase.save`](@ref).
"""
function serializable(mach::Machine{<:Any, C}; verbosity=1) where C
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

# The next line of code makes `serializable` recursive, in the case that `mach.model`
# is a `Composite` model: `save` duplicates the underlying learning network, which
# involves calls to `serializable` on the old machines in the network to create the
# new ones.

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

serializable_fitresult = save(mach.model, mach.fitresult)
serializable_fitresult = save(model, mach.fitresult)

# Duplication currenty needs to happen in two steps for this to work in case of
# `Composite` models.
Expand Down Expand Up @@ -1017,9 +1017,9 @@ useable form.
For an example see [`serializable`](@ref).
"""
function restore!(mach::Machine)
function restore!(mach::Machine, model=mach.model)
mach.state != -1 && return mach
mach.fitresult = restore(mach.model, mach.fitresult)
mach.fitresult = restore(model, mach.fitresult)
mach.state = 1
return mach
end
Expand Down
74 changes: 74 additions & 0 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,80 @@ end
rm(filename)
end

# define a model with non-persistent fitresult:
thing = []
struct EphemeralTransformer <: Unsupervised end
function MLJModelInterface.fit(::EphemeralTransformer, verbosity, X)
view = pointer(thing)
fitresult = (thing, view)
return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.transform(::EphemeralTransformer, fitresult, X)
thing, view = fitresult
return view == pointer(thing) ? X : throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralTransformer, fitresult)
thing, _ = fitresult
return thing
end
function MLJModelInterface.restore(::EphemeralTransformer, serialized_fitresult)
view = pointer(thing)
return (thing, view)
end

# commented out code just tests the transformer above has desired properties for testing:

# # test model transforms:
# model = EphemeralTransformer()
# mach = machine(model, 42) |> fit!
# @test MLJBase.transform(mach, 27) == 27

# # direct serialization fails:
# io = IOBuffer()
# serialize(io, mach)
# seekstart(io)
# mach2 = deserialize(io)
# @test_throws ErrorException("dead fitresult") transform(mach2, 42)

@testset "serialization for model with non-persistent fitresult" begin
X = (; x=randn(5))
mach = machine(EphemeralTransformer(), X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

@testset "serialization for model with non-persistent fitresult in pipeline" begin
# https://github.com/JuliaAI/MLJBase.jl/issues/927
X = (; x=randn(5))
pipe = Standardizer |> EphemeralTransformer
X = (; x=randn(5))
mach = machine(pipe, X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

struct ReportingDynamic <: Unsupervised end
MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple()
MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,))
Expand Down
34 changes: 17 additions & 17 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,22 @@ API.@trait(
[LogLoss(), ], dummy_interval, 1))
end

@everywhere begin
nfolds = 6
nmeasures = 2
func(mach, k) = (
(sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)),
:fitted_params,
:report,
)
end
@testset_accelerated "dispatch of resources and progress meter" accel begin

@info "Checking progress bars:"

X = (x = [1, ],)
y = [2.0, ]

@everywhere begin
nfolds = 6
nmeasures = 2
func(mach, k) = (
(sleep(MLJBase.PROG_METER_DT*rand(rng)); fill(1:k, nmeasures)),
:fitted_params,
:report,
)
end
mach = machine(ConstantRegressor(), X, y)
if accel isa CPUThreads
result = MLJBase._evaluate!(
Expand Down Expand Up @@ -643,15 +643,15 @@ end

struct DummyResamplingStrategy <: MLJBase.ResamplingStrategy end

@testset_accelerated "custom strategy depending on X, y" accel begin
function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy,
rows, X, y)
train = filter(rows) do j
y[j] == y[1]
function MLJBase.train_test_pairs(resampling::DummyResamplingStrategy,
rows, X, y)
train = filter(rows) do j
y[j] == y[1]
end
test = setdiff(rows, train)
return [(train, test),]
end
test = setdiff(rows, train)
return [(train, test),]
end
@testset_accelerated "custom strategy depending on X, y" accel begin

X = (x = rand(rng,8), )
y = categorical(string.([:x, :y, :x, :x, :y, :x, :x, :y]))
Expand Down

0 comments on commit b23bbd6

Please sign in to comment.