Skip to content

Commit

Permalink
annotate type for old_model field of Machine type
Browse files Browse the repository at this point in the history
oops
  • Loading branch information
ablaom committed Apr 8, 2024
1 parent f01a03c commit 7ae5821
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 18 deletions.
32 changes: 20 additions & 12 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ caches_data_by_default(m) = caches_data_by_default(typeof(m))
caches_data_by_default(::Type) = true
caches_data_by_default(::Type{<:Symbol}) = false

mutable struct Machine{M,C} <: MLJType
mutable struct Machine{M,OM,C} <: MLJType

model::M
old_model # for remembering the model used in last call to `fit!`
old_model::OM # for remembering the model used in last call to `fit!`
fitresult
cache

Expand All @@ -77,8 +77,11 @@ mutable struct Machine{M,C} <: MLJType
function Machine(
model::M, args::AbstractNode...;
cache=caches_data_by_default(model),
) where M
mach = new{M,cache}(model)
) where M
# In the case of symbolic model, machine cannot know the type of model to be fit
# at time of construction:
OM = M == Symbol ? Any : M
mach = new{M,OM,cache}(model)
mach.frozen = false
mach.state = 0
mach.args = args
Expand Down Expand Up @@ -115,7 +118,7 @@ any upstream dependencies in a learning network):
replace(mach, :args => (), :data => (), :data_resampled_data => (), :cache => nothing)
"""
function Base.replace(mach::Machine{<:Any,C}, field_value_pairs::Pair...) where C
function Base.replace(mach::Machine{<:Any,<:Any,C}, field_value_pairs::Pair...) where C
# determined new `model` and `args` and build replacement dictionary:
newfield_given_old = Dict(field_value_pairs) # to be extended
fields_to_be_replaced = keys(newfield_given_old)
Expand Down Expand Up @@ -436,8 +439,8 @@ machines(::Source) = Machine[]

## DISPLAY

_cache_status(::Machine{<:Any,true}) = "caches model-specific representations of data"
_cache_status(::Machine{<:Any,false}) = "does not cache data"
_cache_status(::Machine{<:Any,<:Any,true}) = "caches model-specific representations of data"
_cache_status(::Machine{<:Any,<:Any,false}) = "does not cache data"

function Base.show(io::IO, mach::Machine)
model = mach.model
Expand Down Expand Up @@ -502,8 +505,8 @@ end
# for getting model specific representation of the row-restricted
# training data from a machine, according to the value of the machine
# type parameter `C` (`true` or `false`):
_resampled_data(mach::Machine{<:Any,true}, model, rows) = mach.resampled_data
function _resampled_data(mach::Machine{<:Any,false}, model, rows)
_resampled_data(mach::Machine{<:Any,<:Any,true}, model, rows) = mach.resampled_data
function _resampled_data(mach::Machine{<:Any,<:Any,false}, model, rows)
raw_args = map(N -> N(), mach.args)
data = MMI.reformat(model, raw_args...)
return selectrows(model, rows, data...)
Expand All @@ -518,6 +521,10 @@ err_no_real_model(mach) = ErrorException(
"""
)

err_missing_model(model) = ErrorException(
"Specified `composite` model does not have `:$(model)` as a field."
)

"""
last_model(mach::Machine)
Expand Down Expand Up @@ -605,7 +612,7 @@ more on these lower-level training methods.
"""
function fit_only!(
mach::Machine{<:Any,cache_data};
mach::Machine{<:Any,<:Any,cache_data};
rows=nothing,
verbosity=1,
force=false,
Expand All @@ -628,7 +635,8 @@ function fit_only!(
# `getproperty(composite, mach.model)`:
model = if mach.model isa Symbol
isnothing(composite) && throw(err_no_real_model(mach))
mach.model in propertynames(composite)
mach.model in propertynames(composite) ||
throw(err_missing_model(model))
getproperty(composite, mach.model)
else
mach.model
Expand Down Expand Up @@ -967,7 +975,7 @@ A machine returned by `serializable` is characterized by the property
See also [`restore!`](@ref), [`MLJBase.save`](@ref).
"""
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C
function serializable(mach::Machine{<:Any,<:Any,C}, model=mach.model; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach
Expand Down
10 changes: 6 additions & 4 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ for operation in OPERATIONS
operation == :inverse_transform && continue

ex = quote
function $(operation)(mach::Machine{<:Model,false}; rows=:)
function $(operation)(mach::Machine{<:Model,<:Any,false}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && throw(err_serialized($operation))
return ($operation)(mach, mach.args[1](rows=rows))
end
function $(operation)(mach::Machine{<:Model,true}; rows=:)
function $(operation)(mach::Machine{<:Model,<:Any,true}; rows=:)
# catch deserialized machine with no data:
isempty(mach.args) && throw(err_serialized($operation))
model = last_model(mach)
Expand All @@ -92,8 +92,10 @@ for operation in OPERATIONS
end

# special case of Static models (no training arguments):
$operation(mach::Machine{<:Static,true}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,false}; rows=:) = throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,<:Any,true}; rows=:) =
throw(ERR_ROWS_NOT_ALLOWED)
$operation(mach::Machine{<:Static,<:Any,false}; rows=:) =
throw(ERR_ROWS_NOT_ALLOWED)
end
eval(ex)

Expand Down
2 changes: 1 addition & 1 deletion src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ end
@static if VERSION >= v"1.3.0-DEV.573"

# determines if an instantiated machine caches data:
_caches_data(::Machine{M, C}) where {M, C} = C
_caches_data(::Machine{<:Any,<:Any,C}) where C = C

function _evaluate!(func, mach, accel::CPUThreads, nfolds, verbosity)

Expand Down
1 change: 0 additions & 1 deletion test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ end
X = ones(2, 3)

mach = @test_logs machine(Scale(2))
@test mach isa Machine{Scale, false}
transform(mach, X) # triggers training of `mach`, ie is mutating
@test report(mach) in [nothing, NamedTuple()]
@test isnothing(fitted_params(mach))
Expand Down

0 comments on commit 7ae5821

Please sign in to comment.