Skip to content

Commit

Permalink
Merge pull request #392 from MilesCranmer/narrow-mlj-scitype
Browse files Browse the repository at this point in the history
Broaden MLJ `target_scitype` only when using `TemplateExpression`
  • Loading branch information
MilesCranmer authored Jan 3, 2025
2 parents f73670f + 9f82c13 commit 6368a95
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 20 deletions.
3 changes: 3 additions & 0 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,7 @@ function DE.EvaluationHelpersModule._grad_evaluator(
)
end

# Allows special handling of class columns in MLJInterface.jl
handles_class_column(::Type{<:AbstractExpression}) = false

end
62 changes: 45 additions & 17 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ using DynamicQuantities:
dimension
using LossFunctions: SupervisedLoss
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE
using ..CoreModule:
Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
using ..UtilsModule: subscriptify, @ignore
using ..LoggingModule: AbstractSRLogger
using ..TemplateExpressionModule: TemplateExpression

import ..equation_search

Expand All @@ -50,8 +52,9 @@ end

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <:
AbstractSRRegressor
struct_def = :(Base.@kwdef mutable struct $(model_name){
D<:AbstractDimensions,L,E<:AbstractExpression
} <: AbstractSRRegressor
niterations::Int = 100
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
Expand All @@ -62,7 +65,7 @@ function modelexpr(model_name::Symbol)
logger::Union{AbstractSRLogger,Nothing} = nothing
runtests::Bool = true
run_id::Union{String,Nothing} = nothing
loss_type::L = Nothing
loss_type::Type{L} = Nothing
selection_method::Function = choose_best
dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
end)
Expand All @@ -71,7 +74,14 @@ function modelexpr(model_name::Symbol)

# Add everything from `Options` constructor directly to struct:
for (i, option) in enumerate(DEFAULT_OPTIONS)
if getsymb(first(option.args)) == :expression_type
continue
end
insert!(fields, i, Expr(:(=), option.args...))
if getsymb(first(option.args)) == :node_type
# Manually add `expression_type` above, so it can be depended on by `node_type`
insert!(fields, i - 1, :(expression_type::Type{E} = Expression))
end
end

# We also need to create the `get_options` function, based on this:
Expand Down Expand Up @@ -212,12 +222,15 @@ function _update(
options,
class,
)
if isnothing(class) && MMI.istable(X) && haskey(X, :class)
if !(X isa NamedTuple)
error("Classes can only be specified with named tuples.")
end
new_X = Base.structdiff(X, (; X.class))
new_class = X.class
if (
IDE.handles_class_column(m.expression_type) &&
isnothing(class) &&
MMI.istable(X) &&
:class in MMI.schema(X).names
)
names_without_class = filter(!=(:class), MMI.schema(X).names)
new_X = MMI.selectcols(X, collect(names_without_class))
new_class = MMI.selectcols(X, :class)
return _update(
m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_class
)
Expand Down Expand Up @@ -488,12 +501,16 @@ function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegress
)
return _predict(m, fitresult, Xnew.data, Xnew.idx, class)
end
if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class)
if !(Xnew isa NamedTuple)
error("Classes can only be specified with named tuples.")
end
Xnew2 = Base.structdiff(Xnew, (; Xnew.class))
return _predict(m, fitresult, Xnew2, idx, Xnew.class)
if (
IDE.handles_class_column(m.expression_type) &&
isnothing(class) &&
MMI.istable(Xnew) &&
:class in MMI.schema(Xnew).names
)
names_without_class = filter(!=(:class), MMI.schema(Xnew).names)
Xnew2 = MMI.selectcols(Xnew, collect(names_without_class))
class = MMI.selectcols(Xnew, :class)
return _predict(m, fitresult, Xnew2, idx, class)
end

if fitresult.has_class
Expand Down Expand Up @@ -597,7 +614,7 @@ const input_scitype = Union{
MMI.metadata_model(
SRRegressor;
input_scitype,
target_scitype=AbstractVector{<:Any},
target_scitype=AbstractVector{<:MMI.Continuous},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
Expand All @@ -606,13 +623,24 @@ MMI.metadata_model(
MMI.metadata_model(
MultitargetSRRegressor;
input_scitype,
target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}},
target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
human_name="Multi-Target Symbolic Regression via Evolutionary Search",
)

function MMI.target_scitype(
::Type{<:SRRegressor{D,L,E}}
) where {D<:AbstractDimensions,L,E<:TemplateExpression}
return AbstractVector{<:MMI.Unknown}
end
function MMI.target_scitype(
::Type{<:MultitargetSRRegressor{D,L,E}}
) where {D<:AbstractDimensions,L,E<:TemplateExpression}
return Union{MMI.Table(MMI.Unknown),AbstractMatrix{<:MMI.Unknown}}
end

function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)
docstring = """$(MMI.doc_header(eval(model_name)))
Expand Down
9 changes: 6 additions & 3 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Random: default_rng, AbstractRNG

using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights
using ..PopMemberModule: PopMember
using ..InterfaceDynamicExpressionsModule: expected_array_type
using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE
using ..LossFunctionsModule: LossFunctionsModule as LF
using ..ExpressionBuilderModule: ExpressionBuilderModule as EB
using ..MutateModule: MutateModule as MM
Expand Down Expand Up @@ -65,7 +65,7 @@ function DE.eval_tree_array(
options::AbstractOptions;
kws...,
)
A = expected_array_type(X, typeof(tree))
A = IDE.expected_array_type(X, typeof(tree))
out, complete = DE.eval_tree_array(
tree,
X,
Expand All @@ -80,7 +80,7 @@ end
function LF.eval_tree_dispatch(
tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx
)
A = expected_array_type(dataset.X, typeof(tree))
A = IDE.expected_array_type(dataset.X, typeof(tree))
out, complete = DE.eval_tree_array(
tree,
LF.maybe_getindex(dataset.X, :, idx),
Expand Down Expand Up @@ -181,4 +181,7 @@ function MF.mutate_constant(
end
end

# ParametricExpression handles class columns
IDE.handles_class_column(::Type{<:ParametricExpression}) = true

end

0 comments on commit 6368a95

Please sign in to comment.