From 864eb6392e7bcb06eb5401389dc2eddda8aea2e7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 22 Dec 2024 19:44:26 -0500 Subject: [PATCH 1/6] fix: make `:class` col more generic to `X` type --- src/MLJInterface.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 3a8d0a1d..8d2a0045 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -212,12 +212,10 @@ function _update( options, class, ) - if isnothing(class) && MMI.istable(X) && haskey(X, :class) - if !(X isa NamedTuple) - error("Classes can only be specified with named tuples.") - end - new_X = Base.structdiff(X, (; X.class)) - new_class = X.class + if isnothing(class) && MMI.istable(X) && :class in MMI.schema(X).names + names_without_class = filter(!=(:class), MMI.schema(X).names) + new_X = MMI.selectcols(X, collect(names_without_class)) + new_class = MMI.selectcols(X, :class) return _update( m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_class ) From 01e3bfc28862c012733228fc8ccbc2489e1c39b0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 22 Dec 2024 19:49:53 -0500 Subject: [PATCH 2/6] feat: introduce new trait for special class column --- src/InterfaceDynamicExpressions.jl | 3 +++ src/MLJInterface.jl | 20 +++++++++++++------- src/ParametricExpression.jl | 9 ++++++--- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index b3095706..8b79230a 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -359,4 +359,7 @@ function DE.EvaluationHelpersModule._grad_evaluator( ) end +# Allows special handling of class columns in MLJInterface.jl +handles_class_column(::Type{<:AbstractExpression}) = false + end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 8d2a0045..d29fc278 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -25,6 +25,7 @@ using DynamicQuantities: dimension using LossFunctions: SupervisedLoss using ..InterfaceDynamicQuantitiesModule: get_dimensions_type +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS @@ -212,7 +213,10 @@ function _update( options, class, ) - if isnothing(class) && MMI.istable(X) && :class in MMI.schema(X).names + if IDE.handles_class_column(m.expression_type) && + isnothing(class) && + MMI.istable(X) && + :class in MMI.schema(X).names names_without_class = filter(!=(:class), MMI.schema(X).names) new_X = MMI.selectcols(X, collect(names_without_class)) new_class = MMI.selectcols(X, :class) @@ -486,12 +490,14 @@ function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegress ) return _predict(m, fitresult, Xnew.data, Xnew.idx, class) end - if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class) - if !(Xnew isa NamedTuple) - error("Classes can only be specified with named tuples.") - end - Xnew2 = Base.structdiff(Xnew, (; Xnew.class)) - return _predict(m, fitresult, Xnew2, idx, Xnew.class) + if IDE.handles_class_column(m.expression_type) && + isnothing(class) && + MMI.istable(Xnew) && + :class in MMI.schema(Xnew).names + names_without_class = filter(!=(:class), MMI.schema(Xnew).names) + Xnew2 = MMI.selectcols(Xnew, collect(names_without_class)) + class = MMI.selectcols(Xnew, :class) + return _predict(m, fitresult, Xnew2, idx, class) end if fitresult.has_class diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index a5664fd4..5b92db8c 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -17,7 +17,7 @@ using Random: default_rng, AbstractRNG using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, AbstractMutationWeights using ..PopMemberModule: PopMember -using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..LossFunctionsModule: LossFunctionsModule as LF using ..ExpressionBuilderModule: ExpressionBuilderModule as EB using ..MutateModule: MutateModule as MM @@ -65,7 +65,7 @@ function DE.eval_tree_array( options::AbstractOptions; kws..., ) - A = expected_array_type(X, typeof(tree)) + A = IDE.expected_array_type(X, typeof(tree)) out, complete = DE.eval_tree_array( tree, X, @@ -80,7 +80,7 @@ end function LF.eval_tree_dispatch( tree::ParametricExpression, dataset::Dataset, options::AbstractOptions, idx ) - A = expected_array_type(dataset.X, typeof(tree)) + A = IDE.expected_array_type(dataset.X, typeof(tree)) out, complete = DE.eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), @@ -181,4 +181,7 @@ function MF.mutate_constant( end end +# ParametricExpression handles class columns +IDE.handles_class_column(::Type{<:ParametricExpression}) = true + end From bb5eada8f34b665db2d0716fef27fbea6695ff58 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 22 Dec 2024 20:04:59 -0500 Subject: [PATCH 3/6] style: clean up indentation --- src/MLJInterface.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index d29fc278..46fe47ea 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -213,10 +213,12 @@ function _update( options, class, ) - if IDE.handles_class_column(m.expression_type) && + if ( + IDE.handles_class_column(m.expression_type) && isnothing(class) && MMI.istable(X) && :class in MMI.schema(X).names + ) names_without_class = filter(!=(:class), MMI.schema(X).names) new_X = MMI.selectcols(X, collect(names_without_class)) new_class = MMI.selectcols(X, :class) @@ -490,10 +492,12 @@ function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegress ) return _predict(m, fitresult, Xnew.data, Xnew.idx, class) end - if IDE.handles_class_column(m.expression_type) && + if ( + IDE.handles_class_column(m.expression_type) && isnothing(class) && MMI.istable(Xnew) && :class in MMI.schema(Xnew).names + ) names_without_class = filter(!=(:class), MMI.schema(Xnew).names) Xnew2 = MMI.selectcols(Xnew, collect(names_without_class)) class = MMI.selectcols(Xnew, :class) From 21443d49d74881f4896abf1d6ba79d887177f45b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 20 Dec 2024 22:21:02 -0500 Subject: [PATCH 4/6] feat: conditionally widen MLJ scitype --- src/MLJInterface.jl | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 46fe47ea..bf36c29c 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -33,6 +33,7 @@ using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore using ..LoggingModule: AbstractSRLogger +using ..TemplateExpressionModule: TemplateExpression import ..equation_search @@ -51,8 +52,9 @@ end """Generate an `SRRegressor` struct containing all the fields in `Options`.""" function modelexpr(model_name::Symbol) - struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <: - AbstractSRRegressor + struct_def = :(Base.@kwdef mutable struct $(model_name){ + D<:AbstractDimensions,L,E<:AbstractExpression + } <: AbstractSRRegressor niterations::Int = 100 parallelism::Symbol = :multithreading numprocs::Union{Int,Nothing} = nothing @@ -63,7 +65,7 @@ function modelexpr(model_name::Symbol) logger::Union{AbstractSRLogger,Nothing} = nothing runtests::Bool = true run_id::Union{String,Nothing} = nothing - loss_type::L = Nothing + loss_type::Type{L} = Nothing selection_method::Function = choose_best dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} end) @@ -72,7 +74,14 @@ function modelexpr(model_name::Symbol) # Add everything from `Options` constructor directly to struct: for (i, option) in enumerate(DEFAULT_OPTIONS) + if getsymb(first(option.args)) == :expression_type + continue + end insert!(fields, i, Expr(:(=), option.args...)) + if getsymb(first(option.args)) == :node_type + # Manually add `expression_type` above, so it can be depended on by `node_type` + insert!(fields, i - 1, :(expression_type::Type{E} = Expression)) + end end # We also need to create the `get_options` function, based on this: @@ -605,7 +614,7 @@ const input_scitype = Union{ MMI.metadata_model( SRRegressor; input_scitype, - target_scitype=AbstractVector{<:Any}, + target_scitype=AbstractVector{<:MMI.Continuous}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor", @@ -614,13 +623,22 @@ MMI.metadata_model( MMI.metadata_model( MultitargetSRRegressor; input_scitype, - target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}}, + target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor", human_name="Multi-Target Symbolic Regression via Evolutionary Search", ) +function MMI.target_scitype(::Type{<:SRRegressor{<:Any,<:Any,<:TemplateExpression}}) + return AbstractVector{<:Any} +end +function MMI.target_scitype( + ::Type{<:MultitargetSRRegressor{<:Any,<:Any,<:TemplateExpression}} +) + return Union{MMI.Table(Any),AbstractMatrix{<:Any}} +end + function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String) docstring = """$(MMI.doc_header(eval(model_name))) From ea27c1a2e08cb84b87cba6e5067ddd65bef3702d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 20 Dec 2024 22:29:10 -0500 Subject: [PATCH 5/6] fix: ambiguity in target scitype --- src/MLJInterface.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index bf36c29c..1bba4ebc 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -630,12 +630,14 @@ MMI.metadata_model( human_name="Multi-Target Symbolic Regression via Evolutionary Search", ) -function MMI.target_scitype(::Type{<:SRRegressor{<:Any,<:Any,<:TemplateExpression}}) +function MMI.target_scitype( + ::Type{<:SRRegressor{D,L,E}} +) where {D<:AbstractDimensions,L,E<:TemplateExpression} return AbstractVector{<:Any} end function MMI.target_scitype( - ::Type{<:MultitargetSRRegressor{<:Any,<:Any,<:TemplateExpression}} -) + ::Type{<:MultitargetSRRegressor{D,L,E}} +) where {D<:AbstractDimensions,L,E<:TemplateExpression} return Union{MMI.Table(Any),AbstractMatrix{<:Any}} end From 9f82c13666a326d4fd560382001abb2ccda59102 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 21 Dec 2024 16:29:23 -0500 Subject: [PATCH 6/6] fix: switch to `Unknown` rather than `Any` --- src/MLJInterface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 1bba4ebc..2c37c458 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -633,12 +633,12 @@ MMI.metadata_model( function MMI.target_scitype( ::Type{<:SRRegressor{D,L,E}} ) where {D<:AbstractDimensions,L,E<:TemplateExpression} - return AbstractVector{<:Any} + return AbstractVector{<:MMI.Unknown} end function MMI.target_scitype( ::Type{<:MultitargetSRRegressor{D,L,E}} ) where {D<:AbstractDimensions,L,E<:TemplateExpression} - return Union{MMI.Table(Any),AbstractMatrix{<:Any}} + return Union{MMI.Table(MMI.Unknown),AbstractMatrix{<:MMI.Unknown}} end function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)