Skip to content

Commit

Permalink
restore and save optsum for GLMM
Browse files Browse the repository at this point in the history
  • Loading branch information
palday committed Nov 5, 2024
1 parent 9622f36 commit 6f14225
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 48 deletions.
6 changes: 6 additions & 0 deletions src/optsummary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,9 @@ function _check_nlopt_return(ret, failure_modes=_NLOPT_FAILURE_MODES)
@warn("NLopt optimization failure: $ret")
end
end

function Base.:(==)(o1::OptSummary{T}, o2::OptSummary{T}) where {T}
return all(fieldnames(OptSummary)) do fn
return getfield(o1, fn) == getfield(o2, fn)
end
end
105 changes: 62 additions & 43 deletions src/serialization.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,66 @@
"""
restoreoptsum!(m::LinearMixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::LinearMixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::MixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::MixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
Read, check, and restore the `optsum` field from a JSON stream or filename.
"""
function restoreoptsum!(
m::LinearMixedModel{T}, io::IO; atol::Real=zero(T),
rtol::Real=atol > 0 ? zero(T) : eps(T),
) where {T}
function restoreoptsum!(m::MixedModel, filename; kwargs...)
return open(filename, "r") do io
return restoreoptsum!(m, io; kwargs...)
end
end

function restoreoptsum!(m::LinearMixedModel{T}, io::IO;
atol::Real=zero(T),
rtol::Real=atol > 0 ? zero(T) : eps(T)) where {T}
dict = JSON3.read(io)
ops = restoreoptsum!(m.optsum, dict)
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
if !isapprox(
objective(updateL!(setθ!(m, getfield(ops, par)))), getfield(ops, obj_at_par); rtol, atol
)
throw(ArgumentError("model m at $par does not give stored $obj_at_par within given tolerances"))
end
end
return m
end

function restoreoptsum!(m::GeneralizedLinearMixedModel{T}, io::IO;
atol::Real=zero(T),
rtol::Real=atol > 0 ? zero(T) : eps(T)) where {T}
dict = JSON3.read(io)
ops = m.optsum

# need to accommodate fast and slow fits
resize!(ops.initial, length(dict.initial))
resize!(ops.final, length(dict.final))

theta_beta_len = length(m.θ) + length(m.β)
if length(dict.initial) == theta_beta_len # fast=false
if length(ops.lowerbd) == length(m.θ)
prepend!(ops.lowerbd, fill(-Inf, length(m.β)))
end
setpar! = setβθ!
varyβ = false
else # fast=true
setpar! = setθ!
varyβ = true
if length(ops.lowerbd) != length(m.θ)
deleteat!(ops.lowerbd, 1:length(m.β))
end
end
restoreoptsum!(ops, dict)
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
if !isapprox(
deviance(pirls!(setpar!(m, getfield(ops, par)), varyβ), dict.nAGQ), getfield(ops, obj_at_par); rtol, atol
)
throw(ArgumentError("model m at $par does not give stored $obj_at_par within given tolerances"))
end
end
return m
end

function restoreoptsum!(ops::OptSummary{T}, dict::AbstractDict) where {T}
allowed_missing = (
:lowerbd, # never saved, -Inf not allowed in JSON
:xtol_zero_abs, # added in v4.25.0
Expand All @@ -28,16 +79,6 @@ function restoreoptsum!(
@warn "optsum was saved with an older version of MixedModels.jl: consider resaving."
end

# GLMM case with fast=slow
theta_beta_len = length(m.θ) + length(m.β)
if length(dict.initial) == theta_beta_len
resize!(ops.initial, theta_beta_len)
resize!(ops.final, theta_beta_len)
if length(ops.lowerbd) == length(m.θ)
prepend!(ops.lowerbd, fill(-Inf, length(m.β)))
end
end

if any(ops.lowerbd .> dict.initial) || any(ops.lowerbd .> dict.final)
@debug "" ops.lowerbd
@debug "" dict.initial
Expand All @@ -51,13 +92,6 @@ function restoreoptsum!(
ops.xtol_rel = copy(dict.xtol_rel)
copyto!(ops.initial, dict.initial)
copyto!(ops.final, dict.final)
for (v, f) in (:initial => :finitial, :final => :fmin)
if !isapprox(
objective(updateL!(setθ!(m, getfield(ops, v)))), getfield(ops, f); rtol, atol
)
throw(ArgumentError("model m at $v does not give stored $f"))
end
end
ops.optimizer = Symbol(dict.optimizer)
ops.returnvalue = Symbol(dict.returnvalue)
# compatibility with fits saved before the introduction of various extensions
Expand All @@ -73,38 +107,23 @@ function restoreoptsum!(
else
[(convert(Vector{T}, first(entry)), T(last(entry))) for entry in fitlog]
end
return m
end

function restoreoptsum!(m::LinearMixedModel{T}, filename; kwargs...) where {T}
open(filename, "r") do io
restoreoptsum!(m, io; kwargs...)
end
end

function restoreoptsum!(m::GeneralizedLinearMixedModel, fname; kwargs...)
restoreoptsum!(m.LMM, fname; kwargs...)
deviance!(m)
return m
return ops
end

"""
saveoptsum(io::IO, m::LinearMixedModel)
saveoptsum(filename, m::LinearMixedModel)
saveoptsum(io::IO, m::MixedModel)
saveoptsum(filename, m::MixedModel)
Save `m.optsum` (w/o the `lowerbd` field) in JSON format to an IO stream or a file
The reason for omitting the `lowerbd` field is because it often contains `-Inf`
values that are not allowed in JSON.
"""
saveoptsum(io::IO, m::LinearMixedModel) = JSON3.write(io, m.optsum)
function saveoptsum(filename, m::LinearMixedModel)
saveoptsum(io::IO, m::MixedModel) = JSON3.write(io, m.optsum)
function saveoptsum(filename, m::MixedModel)
open(filename, "w") do io
saveoptsum(io, m)
end
end

function saveoptsum(io, m::GeneralizedLinearMixedModel)
return saveoptsum(io, m.LMM)
end
# TODO, maybe: something nice for the MixedModelBootstrap
19 changes: 15 additions & 4 deletions test/pirls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,24 @@ end
cbpp = dataset(:cbpp)
gm_original = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
gm_restored = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
fit!(gm_original; progress=false, nAGQ=1)

io = IOBuffer()

fit!(gm_original; progress=false, nAGQ=1)
saveoptsum(io, gm_original)
saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) deviance(gm_restored)

refit!(gm_original; progress=false, nAGQ=3)
saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
# println(read(seekstart(io), String))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) deviance(gm_restored)

save
refit!(gm_original; progress=false, fast=true)
saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) deviance(gm_restored)
end
2 changes: 1 addition & 1 deletion test/pls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ end
fm_mod = deepcopy(fm)
fm_mod.optsum.fmin += 1
saveoptsum(seekstart(io), fm_mod)
@test_throws(ArgumentError("model m at final does not give stored fmin"),
@test_throws(ArgumentError("model m at final does not give stored fmin within given tolerances"),
restoreoptsum!(m, seekstart(io)))
restoreoptsum!(m, seekstart(io); atol=1)
@test m.optsum.fmin - fm.optsum.fmin 1
Expand Down

0 comments on commit 6f14225

Please sign in to comment.